Unverified Commit dcf46412 authored by Xiangkun Hu's avatar Xiangkun Hu Committed by GitHub
Browse files

[bugfix] Fix dataloader bug in GAT PPI data example (#1966)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* fix dataloader bug
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
Co-authored-by: default avatarZihao Ye <expye@outlook.com>
parent 18bfec24
...@@ -96,12 +96,9 @@ def main(args): ...@@ -96,12 +96,9 @@ def main(args):
if epoch % 5 == 0: if epoch % 5 == 0:
score_list = [] score_list = []
val_loss_list = [] val_loss_list = []
for batch, valid_data in enumerate(valid_dataloader): for batch, subgraph in enumerate(valid_dataloader):
subgraph, feats, labels = valid_data
subgraph = subgraph.to(device) subgraph = subgraph.to(device)
feats = feats.to(device) score, val_loss = evaluate(subgraph.ndata['feat'], model, subgraph, subgraph.ndata['label'], loss_fcn)
labels = labels.to(device)
score, val_loss = evaluate(feats.float(), model, subgraph, labels.float(), loss_fcn)
score_list.append(score) score_list.append(score)
val_loss_list.append(val_loss) val_loss_list.append(val_loss)
mean_score = np.array(score_list).mean() mean_score = np.array(score_list).mean()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment