Unverified Commit 18eaad17 authored by Hengrui Zhang's avatar Hengrui Zhang Committed by GitHub
Browse files

fix bugs (#3008)


Co-authored-by: default avatarzhjwy9343 <6593865@qq.com>
parent 36c6c649
...@@ -68,6 +68,7 @@ This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang9 ...@@ -68,6 +68,7 @@ This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang9
--wd2 float Weight decay of linear classifier. Default is 0.0. --wd2 float Weight decay of linear classifier. Default is 0.0.
--epsilon float Edge mask threshold. Default is 0.01. --epsilon float Edge mask threshold. Default is 0.01.
--hid_dim int Embedding dimension. Default is 512. --hid_dim int Embedding dimension. Default is 512.
--sample_size int Subgraph size. Default is 2000.
``` ```
## How to run examples ## How to run examples
......
...@@ -126,4 +126,4 @@ if __name__ == '__main__': ...@@ -126,4 +126,4 @@ if __name__ == '__main__':
accs.append(acc * 100) accs.append(acc * 100)
accs = th.stack(accs) accs = th.stack(accs)
print(accs.mean().item(), accs.std().item()) print(accs.mean().item(), accs.std().item())
\ No newline at end of file
...@@ -23,6 +23,7 @@ parser.add_argument('--wd1', type=float, default=0., help='Weight decay of mvgrl ...@@ -23,6 +23,7 @@ parser.add_argument('--wd1', type=float, default=0., help='Weight decay of mvgrl
parser.add_argument('--wd2', type=float, default=0., help='Weight decay of linear evaluator.') parser.add_argument('--wd2', type=float, default=0., help='Weight decay of linear evaluator.')
parser.add_argument('--epsilon', type=float, default=0.01, help='Edge mask threshold of diffusion graph.') parser.add_argument('--epsilon', type=float, default=0.01, help='Edge mask threshold of diffusion graph.')
parser.add_argument("--hid_dim", type=int, default=512, help='Hidden layer dim.') parser.add_argument("--hid_dim", type=int, default=512, help='Hidden layer dim.')
parser.add_argument("--sample_size", type=int, default=2000, help='Subgraph size.')
args = parser.parse_args() args = parser.parse_args()
...@@ -54,7 +55,7 @@ if __name__ == '__main__': ...@@ -54,7 +55,7 @@ if __name__ == '__main__':
n_node = graph.number_of_nodes() n_node = graph.number_of_nodes()
sample_size = 2000 sample_size = args.sample_size
lbl1 = th.ones(sample_size * 2) lbl1 = th.ones(sample_size * 2)
lbl2 = th.zeros(sample_size * 2) lbl2 = th.zeros(sample_size * 2)
...@@ -153,4 +154,4 @@ if __name__ == '__main__': ...@@ -153,4 +154,4 @@ if __name__ == '__main__':
accs.append(acc * 100) accs.append(acc * 100)
accs = th.stack(accs) accs = th.stack(accs)
print(accs.mean().item(), accs.std().item()) print(accs.mean().item(), accs.std().item())
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment