"""Best hyperparameters found for the MWE-GCN and MWE-DGCN models."""
import torch

# Hyperparameters for the "MWE-GCN" model on the proteins data; returned by
# get_exp_configure when args["model"] == "MWE-GCN".
# NOTE(review): key semantics below are inferred from key names — confirm
# against the model/training code that consumes this dict.
MWE_GCN_proteins = {
    "num_ew_channels": 8,  # presumably number of edge-weight channels
    "num_epochs": 2000,
    "in_feats": 1,
    "hidden_feats": 10,
    "out_feats": 112,
    "n_layers": 3,
    "lr": 2e-2,
    "weight_decay": 0,
    "patience": 1000,  # presumably early-stopping patience, in epochs
    "dropout": 0.2,
    "aggr_mode": "sum",  # 'sum' or 'concat' for the aggregation across channels
    "ewnorm": "both",  # edge-weight normalization mode — TODO confirm values
}

# Hyperparameters for the "MWE-DGCN" model on the proteins data; returned by
# get_exp_configure when args["model"] == "MWE-DGCN".
# NOTE(review): key semantics below are inferred from key names — confirm
# against the model/training code that consumes this dict.
MWE_DGCN_proteins = {
    "num_ew_channels": 8,  # presumably number of edge-weight channels
    "num_epochs": 2000,
    "in_feats": 1,
    "hidden_feats": 10,
    "out_feats": 112,
    "n_layers": 2,
    "lr": 1e-2,
    "weight_decay": 0,
    "patience": 300,  # presumably early-stopping patience, in epochs
    "dropout": 0.5,
    "aggr_mode": "sum",  # 'sum' or 'concat' for the aggregation across channels
    "residual": True,  # presumably toggles residual connections — confirm
    "ewnorm": "none",  # edge-weight normalization mode — TODO confirm values
}


def get_exp_configure(args):
    """Return the hyperparameter configuration for the requested model.

    Parameters
    ----------
    args : dict
        Experiment arguments. Only ``args["model"]`` is read; supported
        values are ``"MWE-GCN"`` and ``"MWE-DGCN"``.

    Returns
    -------
    dict
        The module-level configuration dict for that model. The same object
        is returned on every call, so copy it before mutating.

    Raises
    ------
    KeyError
        If ``args`` has no ``"model"`` entry.
    ValueError
        If ``args["model"]`` is not a supported model name.
    """
    configs = {
        "MWE-GCN": MWE_GCN_proteins,
        "MWE-DGCN": MWE_DGCN_proteins,
    }
    model = args["model"]
    try:
        return configs[model]
    except KeyError:
        # Fail loudly here instead of returning None, which would only
        # surface later as an opaque error far from the real cause.
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(configs)}"
        ) from None