Unverified Commit e6a15c1a authored by Mingbang Wang, committed by GitHub

[GraphBolt] Enable `node_classification.py` to run in a GPU environment (#6490)
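In summary: this change adds a `--device` command-line flag (`cpu` or `cuda`), inserts a `copy_to` step into the datapipe so each sampled minibatch lands on the target device, allocates the layer-wise inference buffers in pinned CPU memory to speed up host-to-device transfers, and falls back to CPU automatically when CUDA is unavailable.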

parent 2a92dfca
@@ -47,7 +47,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics.functional as MF
-import tqdm
+from tqdm import tqdm
 
 
 def create_dataloader(
@@ -145,6 +145,16 @@ def create_dataloader(
     ############################################################################
     # [Step-5]:
+    # self.copy_to()
+    # [Input]:
+    # 'device': The device to copy the data to.
+    # [Output]:
+    # A CopyTo object to copy the data to the specified device.
+    ############################################################################
+    datapipe = datapipe.copy_to(device=args.device)
+
+    ############################################################################
+    # [Step-6]:
     # gb.MultiProcessDataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.
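For context, the new Step-5 slots a device-transfer stage into the datapipe right before the dataloader is built. Below is a minimal sketch of how the full pipeline might be assembled; the `ItemSampler`, `sample_neighbor`, and `fetch_feature` stages and their parameters are assumptions reconstructed from the surrounding steps of this example, not part of this diff:

```python
import dgl.graphbolt as gb


def build_dataloader(itemset, graph, features, device, fanout=(15, 10, 5)):
    """Hypothetical pipeline mirroring create_dataloader in this example."""
    # Batch seed nodes from the item set.
    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    # Sample a fixed fan-out of neighbors per layer (done on the CPU).
    datapipe = datapipe.sample_neighbor(graph, fanouts=list(fanout))
    # Fetch node features for the sampled subgraphs.
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    # New in this commit: move each finished minibatch to the target device.
    datapipe = datapipe.copy_to(device=device)
    return gb.MultiProcessDataLoader(datapipe, num_workers=0)
```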
@@ -191,10 +201,15 @@ class SAGE(nn.Module):
             hidden_x = self.dropout(hidden_x)
         return hidden_x
 
-    def inference(self, graph, features, dataloader):
+    def inference(self, graph, features, dataloader, device):
         """Conduct layer-wise inference to get all the node embeddings."""
         feature = features.read("node", None, "feat")
+        buffer_device = torch.device("cpu")
+        # Enable pin_memory for faster CPU to GPU data transfer if the
+        # model is running on a GPU.
+        pin_memory = buffer_device != device
 
         for layer_idx, layer in enumerate(self.layers):
             is_last_layer = layer_idx == len(self.layers) - 1
@@ -202,16 +217,21 @@ class SAGE(nn.Module):
                 graph.total_num_nodes,
                 self.out_size if is_last_layer else self.hidden_size,
                 dtype=torch.float64,
+                device=buffer_device,
+                pin_memory=pin_memory,
             )
+            feature = feature.to(device)
 
-            for step, data in tqdm.tqdm(enumerate(dataloader)):
+            for step, data in tqdm(enumerate(dataloader)):
                 x = feature[data.input_nodes]
                 hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                 if not is_last_layer:
                     hidden_x = F.relu(hidden_x)
                     hidden_x = self.dropout(hidden_x)
                 # By design, our output nodes are contiguous.
-                y[data.output_nodes[0] : data.output_nodes[-1] + 1] = hidden_x
+                y[
+                    data.output_nodes[0] : data.output_nodes[-1] + 1
+                ] = hidden_x.to(buffer_device)
             feature = y
 
         return y
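The buffer changes above follow a standard PyTorch pattern: the per-layer output buffer stays on the CPU (bounding GPU memory use), is pinned when a GPU is in play so device-to-host copies are cheaper, and the feature tensor moves to the GPU once per layer rather than once per minibatch. A self-contained sketch of the same pattern, with illustrative shapes:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buffer_device = torch.device("cpu")
# Pinning page-locked host memory only makes sense when a GPU is involved.
pin_memory = buffer_device != device

num_nodes, hidden_size = 1024, 256
# CPU-resident output buffer for one layer's embeddings.
y = torch.empty(
    num_nodes,
    hidden_size,
    dtype=torch.float64,
    device=buffer_device,
    pin_memory=pin_memory,
)
# A minibatch of layer outputs computed on the device...
hidden_x = torch.randn(8, hidden_size, dtype=torch.float64, device=device)
# ...written back into a contiguous slice of the (possibly pinned) buffer.
y[0:8] = hidden_x.to(buffer_device)
```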
@@ -225,7 +245,7 @@ def layerwise_infer(
     dataloader = create_dataloader(
         args, graph, features, all_nodes_set, job="infer"
     )
-    pred = model.inference(graph, features, dataloader)
+    pred = model.inference(graph, features, dataloader, args.device)
     pred = pred[test_set._items[0]]
     label = test_set._items[1].to(pred.device)
@@ -246,7 +266,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
         args, graph, features, itemset, job="evaluate"
     )
-    for step, data in tqdm.tqdm(enumerate(dataloader)):
+    for step, data in tqdm(enumerate(dataloader)):
         x = data.node_features["feat"]
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))
@@ -265,10 +285,10 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
         args, graph, features, train_set, job="train"
     )
 
-    for epoch in tqdm.trange(args.epochs):
+    for epoch in range(args.epochs):
         model.train()
         total_loss = 0
-        for step, data in tqdm.tqdm(enumerate(dataloader)):
+        for step, data in tqdm(enumerate(dataloader)):
             # The input features from the source nodes in the first layer's
             # computation graph.
             x = data.node_features["feat"]
@@ -326,14 +346,28 @@ def parse_args():
         help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
         " identical with the number of layers in your model. Default: 15,10,5",
     )
+    parser.add_argument(
+        "--device",
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
+    )
     return parser.parse_args()
 
 
 def main(args):
+    if not torch.cuda.is_available():
+        args.device = "cpu"
+    print(f"Training in {args.device} mode.")
+    args.device = torch.device(args.device)
+
     # Load and preprocess dataset.
+    print("Loading data...")
     dataset = gb.BuiltinDataset("ogbn-products").load()
     graph = dataset.graph
+    # Currently the neighbor-sampling process can only be done on the CPU,
+    # therefore there is no need to copy the graph to the GPU.
     features = dataset.feature
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
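The `main` changes make the device request degrade gracefully: asking for `cuda` on a CPU-only machine silently drops back to `cpu` instead of crashing. The same logic in isolation (a minimal sketch, not this exact script):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    default="cpu",
    choices=["cpu", "cuda"],
    help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
args = parser.parse_args()

# Fall back to CPU when no CUDA runtime is available.
if not torch.cuda.is_available():
    args.device = "cpu"
print(f"Training in {args.device} mode.")
device = torch.device(args.device)
```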
@@ -348,6 +382,8 @@ def main(args):
     out_size = num_classes
 
     model = SAGE(in_size, hidden_size, out_size)
+    assert len(args.fanout) == len(model.layers)
+    model = model.to(args.device)
 
     # Model training.
     print("Training...")