Unverified Commit 9a6b81ef authored by nxznm's avatar nxznm Committed by GitHub
Browse files

[Bugfix] Improve CompGCN (#3663)


Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent d3930bab
...@@ -9,9 +9,9 @@ This example was implemented by [zhjwy9343](https://github.com/zhjwy9343) and [K ...@@ -9,9 +9,9 @@ This example was implemented by [zhjwy9343](https://github.com/zhjwy9343) and [K
Dependencies Dependencies
---------------------- ----------------------
- pytorch 1.7.1 - pytorch 1.9.0
- dgl 0.6.0 - dgl 0.7.1
- numpy 1.19.4 - numpy 1.20.3
- ordered_set 4.0.2 - ordered_set 4.0.2
Dataset Dataset
...@@ -67,11 +67,11 @@ Performance ...@@ -67,11 +67,11 @@ Performance
| Dataset | FB15k-237 | WN18RR | | Dataset | FB15k-237 | WN18RR |
|---------| ------------------------ | ------------------------ | |---------| ------------------------ | ------------------------ |
| Metric | Paper / ours (dgl) | Paper / ours (dgl) | | Metric | Paper / ours (dgl) | Paper / ours (dgl) |
| MRR | 0.355 / 0.349 | 0.479 / 0.471 | | MRR | 0.355 / 0.348 | 0.479 / 0.466 |
| MR | 197 / 208 | 3533 / 3550 | | MR | 197 / 208 | 3533 / 3542 |
| Hit@10 | 0.535 / 0.526 | 0.546 / 0.532 | | Hit@10 | 0.535 / 0.527 | 0.546 / 0.525 |
| Hit@3 | 0.390 / 0.381 | 0.494 / 0.480 | | Hit@3 | 0.390 / 0.380 | 0.494 / 0.476 |
| Hit@1 | 0.264 / 0.260 | 0.443 / 0.438 | | Hit@1 | 0.264 / 0.259 | 0.443 / 0.435 |
......
...@@ -120,7 +120,7 @@ def main(args): ...@@ -120,7 +120,7 @@ def main(args):
# compute loss # compute loss
tr_loss = loss_fn(logits, label) tr_loss = loss_fn(logits, label)
train_loss.append(tr_loss) train_loss.append(tr_loss.item())
# backward # backward
optimizer.zero_grad() optimizer.zero_grad()
...@@ -142,7 +142,7 @@ def main(args): ...@@ -142,7 +142,7 @@ def main(args):
print("saving model...") print("saving model...")
else: else:
kill_cnt += 1 kill_cnt += 1
if kill_cnt > 25: if kill_cnt > 100:
print('early stop.') print('early stop.')
break break
print("In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}"\ print("In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}"\
...@@ -164,7 +164,7 @@ if __name__ == '__main__': ...@@ -164,7 +164,7 @@ if __name__ == '__main__':
parser.add_argument('--score_func', dest='score_func', default='conve', help='Score Function for Link prediction') parser.add_argument('--score_func', dest='score_func', default='conve', help='Score Function for Link prediction')
parser.add_argument('--opn', dest='opn', default='ccorr', help='Composition Operation to be used in CompGCN') parser.add_argument('--opn', dest='opn', default='ccorr', help='Composition Operation to be used in CompGCN')
parser.add_argument('--batch', dest='batch_size', default=128, type=int, help='Batch size') parser.add_argument('--batch', dest='batch_size', default=1024, type=int, help='Batch size')
parser.add_argument('--gpu', type=int, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0') parser.add_argument('--gpu', type=int, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')
parser.add_argument('--epoch', dest='max_epochs', type=int, default=500, help='Number of epochs') parser.add_argument('--epoch', dest='max_epochs', type=int, default=500, help='Number of epochs')
parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer') parser.add_argument('--l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
......
...@@ -32,7 +32,7 @@ def ccorr(a, b): ...@@ -32,7 +32,7 @@ def ccorr(a, b):
------- -------
Tensor, having the same dimension as the input a. Tensor, having the same dimension as the input a.
""" """
return th.irfft(com_mult(conj(th.rfft(a, 1)), th.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) return th.fft.irfftn(th.conj(th.fft.rfftn(a, (-1))) * th.fft.rfftn(b, (-1)), (-1))
#identify in/out edges, compute edge norm for each and store in edata #identify in/out edges, compute edge norm for each and store in edata
def in_out_norm(graph): def in_out_norm(graph):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment