Unverified commit 0435b74c authored by xiang song (charlie.song), committed by GitHub
Browse files

[Hotfix] Fix layer norm (#2119)



* hotfix

* Fix Layer Norm
Co-authored-by: Ubuntu <ubuntu@ip-172-31-87-240.ec2.internal>
parent 715b3b16
......@@ -63,10 +63,10 @@ Test-bd: P3-8xlarge
OGBN-MAG accuracy 46.22
```
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,30' --batch-size 512 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --node-feats --layer-norm
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,30' --batch-size 512 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --node-feats
```
OGBN-MAG without node-feats 43.24
OGBN-MAG without node-feats 43.63
```
python3 entity_classify_mp.py -d ogbn-mag --testing --fanout='25,25' --batch-size 256 --n-hidden 64 --lr 0.01 --num-worker 0 --eval-batch-size 8 --low-mem --gpu 0,1,2,3,4,5,6,7 --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --mix-cpu-gpu --layer-norm
```
......
......@@ -86,19 +86,19 @@ class EntityClassify(nn.Module):
self.layers.append(RelGraphConv(
self.h_dim, self.h_dim, self.num_rels, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
low_mem=self.low_mem, dropout=self.dropout))
low_mem=self.low_mem, dropout=self.dropout, layer_norm = layer_norm))
# h2h
for idx in range(self.num_hidden_layers):
self.layers.append(RelGraphConv(
self.h_dim, self.h_dim, self.num_rels, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
low_mem=self.low_mem, dropout=self.dropout))
low_mem=self.low_mem, dropout=self.dropout, layer_norm = layer_norm))
# h2o
self.layers.append(RelGraphConv(
self.h_dim, self.out_dim, self.num_rels, "basis",
self.num_bases, activation=None,
self_loop=self.use_self_loop,
low_mem=self.low_mem))
low_mem=self.low_mem, layer_norm = layer_norm))
def forward(self, blocks, feats, norm=None):
if blocks is None:
......@@ -152,12 +152,10 @@ class NeighborSampler:
else:
frontier = dgl.sampling.sample_neighbors(self.g, cur, fanout)
etypes = self.g.edata[dgl.ETYPE][frontier.edata[dgl.EID]]
norm = self.g.edata['norm'][frontier.edata[dgl.EID]]
block = dgl.to_block(frontier, cur)
block.srcdata[dgl.NTYPE] = self.g.ndata[dgl.NTYPE][block.srcdata[dgl.NID]]
block.srcdata['type_id'] =self.g.ndata[dgl.NID][block.srcdata[dgl.NID]]
block.srcdata['type_id'] = self.g.ndata[dgl.NID][block.srcdata[dgl.NID]]
block.edata['etype'] = etypes
block.edata['norm'] = norm
cur = block.srcdata[dgl.NID]
blocks.insert(0, block)
return seeds, blocks
......
......@@ -175,7 +175,7 @@ class RelGraphConv(nn.Module):
# layer norm
if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(n_hidden, elementwise_affine=True)
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
# weight for self loop
if self.self_loop:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.