Unverified Commit d912947b authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable gpu train for ogbn-mag (#6373)

parent cef5a14a
...@@ -4,18 +4,25 @@ This example aims to demonstrate how to run node classification task on heteroge ...@@ -4,18 +4,25 @@ This example aims to demonstrate how to run node classification task on heteroge
## Run on `ogbn-mag` dataset ## Run on `ogbn-mag` dataset
### Command ### Sample on CPU and train/infer on CPU
``` ```
python3 hetero_rgcn.py python3 hetero_rgcn.py --dataset ogbn-mag
``` ```
### Statistics of train/validation/test ### Sample on CPU and train/infer on GPU
Below results are run on AWS EC2 r6idn.metal, 1024GB RAM, 128 vCPUs(Ice Lake 8375C), 0 GPUs. ```
python3 hetero_rgcn.py --dataset ogbn-mag --num_gpus 1
```
| Dataset Size | Peak CPU RAM Usage | Time Per Epoch(Training) | Time Per Epoch(Inference: train/val/test set) | ### Resource usage and time cost
| ------------ | ------------- | ------------------------ | --------------------------- | Below results are roughly collected from an AWS EC2 **g4dn.metal**, 384GB RAM, 96 vCPUs(Cascade Lake P-8259L), 8 NVIDIA T4 GPUs.
| ~1.1GB | ~5GB | ~3min | ~1min40s + ~0min9s + ~0min7s |
| Dataset Size | CPU RAM Usage | Num of GPUs | GPU RAM Usage | Time Per Epoch(Training) | Time Per Epoch(Inference: train/val/test set) |
| ------------ | ------------- | ----------- | ---------- | --------- | --------------------------- |
| ~1.1GB | ~5GB | 0 | 0GB | ~4min5s | ~2min7s + ~0min12s + ~0min8s |
| ~1.1GB | ~4.3GB | 1 | 4.7GB | ~1min18s | ~1min54s + ~0min12s + ~0min8s |
### Accuracies
``` ```
Final performance: Final performance:
All runs: All runs:
......
...@@ -76,7 +76,6 @@ def load_dataset(dataset_name): ...@@ -76,7 +76,6 @@ def load_dataset(dataset_name):
valid_set = dataset.tasks[0].validation_set valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set test_set = dataset.tasks[0].test_set
num_classes = dataset.tasks[0].metadata["num_classes"] num_classes = dataset.tasks[0].metadata["num_classes"]
print(len(train_set), len(valid_set), len(test_set))
return graph, features, train_set, valid_set, test_set, num_classes return graph, features, train_set, valid_set, test_set, num_classes
...@@ -127,6 +126,8 @@ def create_dataloader( ...@@ -127,6 +126,8 @@ def create_dataloader(
# Move the mini-batch to the appropriate device. # Move the mini-batch to the appropriate device.
# `device`: # `device`:
# The device to move the mini-batch to. # The device to move the mini-batch to.
# [TODO] Moving `MiniBatch` to GPU is not supported yet.
device = th.device("cpu")
datapipe = datapipe.copy_to(device) datapipe = datapipe.copy_to(device)
# Create a DataLoader from the datapipe. # Create a DataLoader from the datapipe.
...@@ -424,11 +425,14 @@ class Logger(object): ...@@ -424,11 +425,14 @@ class Logger(object):
def extract_node_features(name, block, data, node_embed, device): def extract_node_features(name, block, data, node_embed, device):
"""Extract the node features from embedding layer or raw features.""" """Extract the node features from embedding layer or raw features."""
if name == "ogbn-mag": if name == "ogbn-mag":
input_nodes = {k: v.to(device) for k, v in data.input_nodes.items()}
# Extract node embeddings for the input nodes. # Extract node embeddings for the input nodes.
node_features = extract_embed(node_embed, data.input_nodes) node_features = extract_embed(node_embed, input_nodes)
# Add the batch's raw "paper" features. Corresponds to the content # Add the batch's raw "paper" features. Corresponds to the content
# in the function `rel_graph_embed` comment. # in the function `rel_graph_embed` comment.
node_features.update({"paper": data.node_features[("paper", "feat")]}) node_features.update(
{"paper": data.node_features[("paper", "feat")].to(device)}
)
else: else:
node_features = { node_features = {
ntype: block.srcnodes[ntype].data["feat"] ntype: block.srcnodes[ntype].data["feat"]
...@@ -491,7 +495,7 @@ def evaluate( ...@@ -491,7 +495,7 @@ def evaluate(
y_true = list() y_true = list()
for data in tqdm(data_loader, desc="Inference"): for data in tqdm(data_loader, desc="Inference"):
blocks = data.to_dgl_blocks() blocks = [block.to(device) for block in data.to_dgl_blocks()]
node_features = extract_node_features( node_features = extract_node_features(
name, blocks[0], data, node_embed, device name, blocks[0], data, node_embed, device
) )
...@@ -503,7 +507,7 @@ def evaluate( ...@@ -503,7 +507,7 @@ def evaluate(
# argmax. # argmax.
y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True) y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
y_hats.append(y_hat.cpu()) y_hats.append(y_hat.cpu())
y_true.append(data.labels[category].long().cpu()) y_true.append(data.labels[category].long())
y_pred = th.cat(y_hats, dim=0) y_pred = th.cat(y_hats, dim=0)
y_true = th.cat(y_true, dim=0) y_true = th.cat(y_true, dim=0)
...@@ -562,7 +566,7 @@ def run( ...@@ -562,7 +566,7 @@ def run(
num_seeds = data.seed_nodes[category].shape[0] num_seeds = data.seed_nodes[category].shape[0]
# Convert MiniBatch to DGL Blocks. # Convert MiniBatch to DGL Blocks.
blocks = data.to_dgl_blocks() blocks = [block.to(device) for block in data.to_dgl_blocks()]
# Extract the node features from embedding layer or raw features. # Extract the node features from embedding layer or raw features.
node_features = extract_node_features( node_features = extract_node_features(
...@@ -574,7 +578,7 @@ def run( ...@@ -574,7 +578,7 @@ def run(
# Generate predictions. # Generate predictions.
logits = model(node_features, blocks)[category] logits = model(node_features, blocks)[category]
y_hat = logits.log_softmax(dim=-1) y_hat = logits.log_softmax(dim=-1).cpu()
loss = F.nll_loss(y_hat, data.labels[category].long()) loss = F.nll_loss(y_hat, data.labels[category].long())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -625,9 +629,7 @@ def run( ...@@ -625,9 +629,7 @@ def run(
def main(args): def main(args):
if args.gpu > 0: device = th.device("cuda") if args.num_gpus > 0 else th.device("cpu")
raise RuntimeError("GPU training is not supported.")
device = th.device("cpu")
# Initialize a logger. # Initialize a logger.
logger = Logger(args.runs) logger = Logger(args.runs)
...@@ -729,7 +731,7 @@ if __name__ == "__main__": ...@@ -729,7 +731,7 @@ if __name__ == "__main__":
) )
parser.add_argument("--runs", type=int, default=5) parser.add_argument("--runs", type=int, default=5)
parser.add_argument("--num_workers", type=int, default=0) parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--gpu", type=int, default=0) parser.add_argument("--num_gpus", type=int, default=0)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment