Unverified Commit 2f47c241 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix incorrect call in example (#6563)

parent 07db09f7
...@@ -46,6 +46,7 @@ import argparse ...@@ -46,6 +46,7 @@ import argparse
import itertools import itertools
import sys import sys
import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
import dgl.nn as dglnn import dgl.nn as dglnn
...@@ -428,7 +429,9 @@ class Logger(object): ...@@ -428,7 +429,9 @@ class Logger(object):
def extract_node_features(name, block, data, node_embed, device): def extract_node_features(name, block, data, node_embed, device):
"""Extract the node features from embedding layer or raw features.""" """Extract the node features from embedding layer or raw features."""
if name == "ogbn-mag": if name == "ogbn-mag":
input_nodes = {k: v.to(device) for k, v in data.input_nodes.items()} input_nodes = {
k: v.to(device) for k, v in block.srcdata[dgl.NID].items()
}
# Extract node embeddings for the input nodes. # Extract node embeddings for the input nodes.
node_features = extract_embed(node_embed, input_nodes) node_features = extract_embed(node_embed, input_nodes)
# Add the batch's raw "paper" features. Corresponds to the content # Add the batch's raw "paper" features. Corresponds to the content
...@@ -554,12 +557,12 @@ def run( ...@@ -554,12 +557,12 @@ def run(
num_workers=num_workers, num_workers=num_workers,
) )
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"): for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
# Fetch the number of seed nodes in the batch.
num_seeds = data.output_nodes[category].shape[0]
# Convert MiniBatch to DGL Blocks. # Convert MiniBatch to DGL Blocks.
blocks = [block.to(device) for block in data.blocks] blocks = [block.to(device) for block in data.blocks]
# Fetch the number of seed nodes in the batch.
num_seeds = blocks[-1].num_dst_nodes(category)
# Extract the node features from embedding layer or raw features. # Extract the node features from embedding layer or raw features.
node_features = extract_node_features( node_features = extract_node_features(
name, blocks[0], data, node_embed, device name, blocks[0], data, node_embed, device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment