Unverified Commit d912947b authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable gpu train for ogbn-mag (#6373)

parent cef5a14a
......@@ -4,18 +4,25 @@ This example aims to demonstrate how to run node classification task on heteroge
## Run on `ogbn-mag` dataset
### Command
### Sample on CPU and train/infer on CPU
```
python3 hetero_rgcn.py
python3 hetero_rgcn.py --dataset ogbn-mag
```
### Statistics of train/validation/test
Below results are run on AWS EC2 r6idn.metal, 1024GB RAM, 128 vCPUs(Ice Lake 8375C), 0 GPUs.
### Sample on CPU and train/infer on GPU
```
python3 hetero_rgcn.py --dataset ogbn-mag --num_gups 1
```
### Resource usage and time cost
Below results are roughly collected from an AWS EC2 **g4dn.metal**, 384GB RAM, 96 vCPUs(Cascade Lake P-8259L), 8 NVIDIA T4 GPUs.
| Dataset Size | Peak CPU RAM Usage | Time Per Epoch(Training) | Time Per Epoch(Inference: train/val/test set) |
| ------------ | ------------- | ------------------------ | --------------------------- |
| ~1.1GB | ~5GB | ~3min | ~1min40s + ~0min9s + ~0min7s |
| Dataset Size | CPU RAM Usage | Num of GPUs | GPU RAM Usage | Time Per Epoch(Training) | Time Per Epoch(Inference: train/val/test set) |
| ------------ | ------------- | ----------- | ---------- | --------- | --------------------------- |
| ~1.1GB | ~5GB | 0 | 0GB | ~4min5s | ~2min7s + ~0min12s + ~0min8s |
| ~1.1GB | ~4.3GB | 1 | 4.7GB | ~1min18s | ~1min54s + ~0min12s + ~0min8s |
### Accuracies
```
Final performance:
All runs:
......
......@@ -76,7 +76,6 @@ def load_dataset(dataset_name):
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
num_classes = dataset.tasks[0].metadata["num_classes"]
print(len(train_set), len(valid_set), len(test_set))
return graph, features, train_set, valid_set, test_set, num_classes
......@@ -127,6 +126,8 @@ def create_dataloader(
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
# [TODO] Moving `MiniBatch` to GPU is not supported yet.
device = th.device("cpu")
datapipe = datapipe.copy_to(device)
# Create a DataLoader from the datapipe.
......@@ -424,11 +425,14 @@ class Logger(object):
def extract_node_features(name, block, data, node_embed, device):
"""Extract the node features from embedding layer or raw features."""
if name == "ogbn-mag":
input_nodes = {k: v.to(device) for k, v in data.input_nodes.items()}
# Extract node embeddings for the input nodes.
node_features = extract_embed(node_embed, data.input_nodes)
node_features = extract_embed(node_embed, input_nodes)
# Add the batch's raw "paper" features. Corresponds to the content
# in the function `rel_graph_embed` comment.
node_features.update({"paper": data.node_features[("paper", "feat")]})
node_features.update(
{"paper": data.node_features[("paper", "feat")].to(device)}
)
else:
node_features = {
ntype: block.srcnodes[ntype].data["feat"]
......@@ -491,7 +495,7 @@ def evaluate(
y_true = list()
for data in tqdm(data_loader, desc="Inference"):
blocks = data.to_dgl_blocks()
blocks = [block.to(device) for block in data.to_dgl_blocks()]
node_features = extract_node_features(
name, blocks[0], data, node_embed, device
)
......@@ -503,7 +507,7 @@ def evaluate(
# argmax.
y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
y_hats.append(y_hat.cpu())
y_true.append(data.labels[category].long().cpu())
y_true.append(data.labels[category].long())
y_pred = th.cat(y_hats, dim=0)
y_true = th.cat(y_true, dim=0)
......@@ -562,7 +566,7 @@ def run(
num_seeds = data.seed_nodes[category].shape[0]
# Convert MiniBatch to DGL Blocks.
blocks = data.to_dgl_blocks()
blocks = [block.to(device) for block in data.to_dgl_blocks()]
# Extract the node features from embedding layer or raw features.
node_features = extract_node_features(
......@@ -574,7 +578,7 @@ def run(
# Generate predictions.
logits = model(node_features, blocks)[category]
y_hat = logits.log_softmax(dim=-1)
y_hat = logits.log_softmax(dim=-1).cpu()
loss = F.nll_loss(y_hat, data.labels[category].long())
loss.backward()
optimizer.step()
......@@ -625,9 +629,7 @@ def run(
def main(args):
if args.gpu > 0:
raise RuntimeError("GPU training is not supported.")
device = th.device("cpu")
device = th.device("cuda") if args.num_gpus > 0 else th.device("cpu")
# Initialize a logger.
logger = Logger(args.runs)
......@@ -729,7 +731,7 @@ if __name__ == "__main__":
)
parser.add_argument("--runs", type=int, default=5)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--num_gpus", type=int, default=0)
args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment