Unverified Commit 0c70bc23 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] Refactor SIGN model. (#4921)


Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 4d641aba
......@@ -57,24 +57,31 @@ def evaluate(g, pred):
return val_acc, test_acc
def train(g, model):
labels = g.ndata["label"]
def train(g, X_sign, model):
label = g.ndata["label"]
train_mask = g.ndata["train_mask"]
optimizer = Adam(model.parameters(), lr=3e-3)
for epoch in range(10):
# Switch the model to training mode.
model.train()
# Forward.
logits = model(X_sign)
# Compute loss with nodes in training set.
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
loss = F.cross_entropy(logits[train_mask], label[train_mask])
# Backward.
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Switch the model to evaluating mode.
model.eval()
# Compute prediction.
logits = model(X_sign)
pred = logits.argmax(1)
# Evaluate the prediction.
......@@ -117,4 +124,4 @@ if __name__ == "__main__":
model = SIGN(in_size, out_size, r).to(dev)
# Kick off training.
train(g, model)
train(g, X_sign, model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment