Unverified Commit 4b6faecb authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Misc] Modify two examples to make them consistent (#6513)

parent 96226c61
...@@ -181,7 +181,7 @@ class SAGE(nn.Module): ...@@ -181,7 +181,7 @@ class SAGE(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.out_size = out_size self.out_size = out_size
# Set the dtype for the layers manually. # Set the dtype for the layers manually.
self.set_layer_dtype(torch.float64) self.set_layer_dtype(torch.float32)
def set_layer_dtype(self, _dtype): def set_layer_dtype(self, _dtype):
for layer in self.layers: for layer in self.layers:
...@@ -221,7 +221,7 @@ class SAGE(nn.Module): ...@@ -221,7 +221,7 @@ class SAGE(nn.Module):
for step, data in tqdm(enumerate(dataloader)): for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes] x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1 hidden_x = layer(data.blocks[0], x.float()) # len(blocks) = 1
if not is_last_layer: if not is_last_layer:
hidden_x = F.relu(hidden_x) hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x) hidden_x = self.dropout(hidden_x)
...@@ -266,7 +266,7 @@ def evaluate(args, model, graph, features, itemset, num_classes): ...@@ -266,7 +266,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
for step, data in tqdm(enumerate(dataloader)): for step, data in tqdm(enumerate(dataloader)):
x = data.node_features["feat"] x = data.node_features["feat"]
y.append(data.labels) y.append(data.labels)
y_hats.append(model(data.blocks, x)) y_hats.append(model(data.blocks, x.float()))
return MF.accuracy( return MF.accuracy(
torch.cat(y_hats), torch.cat(y_hats),
...@@ -286,7 +286,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model): ...@@ -286,7 +286,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
t0 = time.time() t0 = time.time()
model.train() model.train()
total_loss = 0 total_loss = 0
for step, data in tqdm(enumerate(dataloader)): for step, data in enumerate(dataloader):
# The input features from the source nodes in the first layer's # The input features from the source nodes in the first layer's
# computation graph. # computation graph.
x = data.node_features["feat"] x = data.node_features["feat"]
...@@ -295,7 +295,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model): ...@@ -295,7 +295,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
# in the last layer's computation graph. # in the last layer's computation graph.
y = data.labels y = data.labels
y_hat = model(data.blocks, x) y_hat = model(data.blocks, x.float())
# Compute loss. # Compute loss.
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
...@@ -399,7 +399,7 @@ def main(args): ...@@ -399,7 +399,7 @@ def main(args):
model, model,
num_classes, num_classes,
) )
print(f"Test Accuracy is {test_acc.item():.4f}") print(f"Test accuracy {test_acc.item():.4f}")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -35,6 +35,7 @@ main ...@@ -35,6 +35,7 @@ main
""" """
import argparse import argparse
import time
import dgl import dgl
import dgl.nn as dglnn import dgl.nn as dglnn
...@@ -228,6 +229,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva): ...@@ -228,6 +229,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
for epoch in range(10): for epoch in range(10):
t0 = time.time()
model.train() model.train()
total_loss = 0 total_loss = 0
# A block is a graph consisting of two sets of nodes: the # A block is a graph consisting of two sets of nodes: the
...@@ -252,10 +254,11 @@ def train(args, device, g, dataset, model, num_classes, use_uva): ...@@ -252,10 +254,11 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
loss.backward() loss.backward()
opt.step() opt.step()
total_loss += loss.item() total_loss += loss.item()
t1 = time.time()
acc = evaluate(model, g, val_dataloader, num_classes) acc = evaluate(model, g, val_dataloader, num_classes)
print( print(
f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | " f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | "
f"Accuracy {acc.item():.4f} " f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
) )
...@@ -297,4 +300,4 @@ if __name__ == "__main__": ...@@ -297,4 +300,4 @@ if __name__ == "__main__":
acc = layerwise_infer( acc = layerwise_infer(
device, g, dataset.test_idx, model, num_classes, batch_size=4096 device, g, dataset.test_idx, model, num_classes, batch_size=4096
) )
print(f"Test Accuracy {acc.item():.4f}") print(f"Test accuracy {acc.item():.4f}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment