Unverified Commit 6e8f7605 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update (#2057)

parent 75e89a15
......@@ -62,7 +62,8 @@ class HANLayer(nn.Module):
self.gat_layers = nn.ModuleList()
for i in range(len(meta_paths)):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu))
dropout, dropout, activation=F.elu,
allow_zero_in_degree=True))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)
......
......@@ -99,7 +99,7 @@ def setup(args):
args.update(default_configure)
set_random_seed(args['seed'])
args['dataset'] = 'ACMRaw' if args['hetero'] else 'ACM'
args['device'] = 'cuda: 0' if torch.cuda.is_available() else 'cpu'
args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['log_dir'] = setup_log_dir(args)
return args
......@@ -107,7 +107,7 @@ def setup_for_sampling(args):
args.update(default_configure)
args.update(sampling_configure)
set_random_seed()
args['device'] = 'cuda: 0' if torch.cuda.is_available() else 'cpu'
args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['log_dir'] = setup_log_dir(args, sampling=True)
return args
......@@ -188,7 +188,7 @@ def load_acm_raw(remove_self_loop):
hg = dgl.heterograph({
('paper', 'pa', 'author'): p_vs_a.nonzero(),
('author', 'ap', 'paper'): p_vs_a.transpose.nonzero(),
('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
('paper', 'pf', 'field'): p_vs_l.nonzero(),
('field', 'fp', 'paper'): p_vs_l.transpose().nonzero()
})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment