Commit 07932207 authored by rusty1s's avatar rusty1s
Browse files

restructure

parent ec590171
...@@ -105,6 +105,7 @@ class GIN(ScalableGNN): ...@@ -105,6 +105,7 @@ class GIN(ScalableGNN):
@torch.no_grad() @torch.no_grad()
def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor, state): def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor, state):
# Perform layer-wise inference:
if layer == 0: if layer == 0:
x = self.lins[0](x).relu_() x = self.lins[0](x).relu_()
...@@ -148,6 +149,13 @@ def train(model, loader, optimizer): ...@@ -148,6 +149,13 @@ def train(model, loader, optimizer):
return total_loss / total_examples return total_loss / total_examples
@torch.no_grad()
def mini_test(model, loader, y):
model.eval()
out = model(loader=loader)
return int((out.argmax(dim=-1) == y).sum()) / y.size(0)
@torch.no_grad() @torch.no_grad()
def full_test(model, loader): def full_test(model, loader):
model.eval() model.eval()
...@@ -162,13 +170,6 @@ def full_test(model, loader): ...@@ -162,13 +170,6 @@ def full_test(model, loader):
return total_correct / total_examples return total_correct / total_examples
@torch.no_grad()
def mini_test(model, loader, y):
model.eval()
out = model(loader=loader)
return int((out.argmax(dim=-1) == y).sum()) / y.size(0)
mini_test(model, eval_loader, data.y) # Fill history. mini_test(model, eval_loader, data.y) # Fill history.
for epoch in range(1, 151): for epoch in range(1, 151):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment