Unverified Commit 87310503 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

[Example] Fix GIN Example (#826)

* Add serialization

* fix gin

* Add serialization

* fix gin

* fix
parent ddeb86f9
...@@ -20,12 +20,12 @@ How to run ...@@ -20,12 +20,12 @@ How to run
An experiment on the GIN in default settings can be run with An experiment on the GIN in default settings can be run with
```bash ```bash
python3 main.py python main.py
``` ```
An experiment on the GIN in customized settings can be run with An experiment on the GIN in customized settings can be run with
```bash ```bash
python3 main.py [--device 0 | --disable-cuda] --dataset COLLAB \ python main.py [--device 0 | --disable-cuda] --dataset COLLAB \
--graph_pooling_type max --neighbor_pooling_type sum --graph_pooling_type max --neighbor_pooling_type sum
``` ```
...@@ -35,7 +35,7 @@ Results ...@@ -35,7 +35,7 @@ Results
Run with following with the double SUM pooling way: Run with following with the double SUM pooling way:
(tested dataset: "MUTAG"(default), "COLLAB", "IMDBBINARY", "IMDBMULTI") (tested dataset: "MUTAG"(default), "COLLAB", "IMDBBINARY", "IMDBMULTI")
```bash ```bash
python3 train.py --dataset MUTAB --device 0 \ python main.py --dataset MUTAG --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum --graph_pooling_type sum --neighbor_pooling_type sum
``` ```
......
...@@ -156,6 +156,8 @@ class GIN(nn.Module): ...@@ -156,6 +156,8 @@ class GIN(nn.Module):
for layer in range(self.num_layers - 1): for layer in range(self.num_layers - 1):
h = self.ginlayers[layer](g, h) h = self.ginlayers[layer](g, h)
h = self.batch_norms[layer](h)
h = F.relu(h)
hidden_rep.append(h) hidden_rep.append(h)
score_over_layer = 0 score_over_layer = 0
......
...@@ -147,8 +147,8 @@ def main(args): ...@@ -147,8 +147,8 @@ def main(args):
f.write("\n") f.write("\n")
lrbar.set_description( lrbar.set_description(
"the learning eps with learn_eps={} is: {}".format( "Learning eps with learn_eps={}: {}".format(
args.learn_eps, [layer.eps.data for layer in model.ginlayers])) args.learn_eps, [layer.eps.data.item() for layer in model.ginlayers]))
tbar.close() tbar.close()
vbar.close() vbar.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment