Unverified Commit 2d8d6fbb authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Misc] Modify two node examples to make them consitent. (#6599)

parent a664b0c4
......@@ -207,7 +207,7 @@ class SAGE(nn.Module):
y = torch.empty(
graph.total_num_nodes,
self.out_size if is_last_layer else self.hidden_size,
dtype=torch.float64,
dtype=torch.float32,
device=buffer_device,
pin_memory=pin_memory,
)
......@@ -215,7 +215,7 @@ class SAGE(nn.Module):
for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x.float()) # len(blocks) = 1
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
......@@ -274,7 +274,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
for step, data in tqdm(enumerate(dataloader)):
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x.float()))
y_hats.append(model(data.blocks, x))
return MF.accuracy(
torch.cat(y_hats),
......@@ -310,7 +310,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
# in the last layer's computation graph.
y = data.labels
y_hat = model(data.blocks, x.float())
y_hat = model(data.blocks, x)
# Compute loss.
loss = F.cross_entropy(y_hat, y)
......
......@@ -274,10 +274,16 @@ if __name__ == "__main__":
parser.add_argument(
"--mode",
default="mixed",
choices=["cpu", "mixed", "gpu", "compare-to-graphbolt"],
choices=["cpu", "mixed", "gpu"],
help="Training mode. 'cpu' for CPU training, 'mixed' for "
"CPU-GPU mixed training, 'gpu' for pure-GPU training.",
)
parser.add_argument(
"--compare-to-graphbolt",
default="false",
choices=["false", "true"],
help="Whether comparing to GraphBolt or not, 'false' by default.",
)
args = parser.parse_args()
if not torch.cuda.is_available():
args.mode = "cpu"
......@@ -286,13 +292,15 @@ if __name__ == "__main__":
# Load and preprocess dataset.
print("Loading data")
dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
g = dataset[0]
if args.compare_to_graphbolt == "false":
g = g.to("cuda" if args.mode == "gpu" else "cpu")
num_classes = dataset.num_classes
# Whether use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva = args.mode == "mixed"
device = torch.device("cpu" if args.mode == "cpu" else "cuda")
fused_sampling = args.mode != "compare-to-graphbolt"
fused_sampling = args.compare_to_graphbolt == "false"
# Create GraphSAGE model.
in_size = g.ndata["feat"].shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment