Unverified Commit 0c70bc23 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] Refactor SIGN model. (#4921)


Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 4d641aba
...@@ -57,24 +57,31 @@ def evaluate(g, pred): ...@@ -57,24 +57,31 @@ def evaluate(g, pred):
return val_acc, test_acc return val_acc, test_acc
def train(g, model): def train(g, X_sign, model):
labels = g.ndata["label"] label = g.ndata["label"]
train_mask = g.ndata["train_mask"] train_mask = g.ndata["train_mask"]
optimizer = Adam(model.parameters(), lr=3e-3) optimizer = Adam(model.parameters(), lr=3e-3)
for epoch in range(10): for epoch in range(10):
# Switch the model to training mode.
model.train()
# Forward. # Forward.
logits = model(X_sign) logits = model(X_sign)
# Compute loss with nodes in training set. # Compute loss with nodes in training set.
loss = F.cross_entropy(logits[train_mask], labels[train_mask]) loss = F.cross_entropy(logits[train_mask], label[train_mask])
# Backward. # Backward.
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Switch the model to evaluating mode.
model.eval()
# Compute prediction. # Compute prediction.
logits = model(X_sign)
pred = logits.argmax(1) pred = logits.argmax(1)
# Evaluate the prediction. # Evaluate the prediction.
...@@ -117,4 +124,4 @@ if __name__ == "__main__": ...@@ -117,4 +124,4 @@ if __name__ == "__main__":
model = SIGN(in_size, out_size, r).to(dev) model = SIGN(in_size, out_size, r).to(dev)
# Kick off training. # Kick off training.
train(g, model) train(g, X_sign, model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment