Unverified Commit b4cd60a9 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix relgraphconv bug (#3256)

parent 8341244a
......@@ -17,43 +17,43 @@ pip install requests torch rdflib pandas
Example code was tested with rdflib 4.2.2 and pandas 0.23.4
### Entity Classification
AIFB: accuracy 92.59% (3 runs, DGL), 95.83% (paper)
AIFB: accuracy 96.29% (3 runs, DGL), 95.83% (paper)
```
python3 entity_classify.py -d aifb --testing --gpu 0
```
MUTAG: accuracy 72.55% (3 runs, DGL), 73.23% (paper)
MUTAG: accuracy 70.59% (3 runs, DGL), 73.23% (paper)
```
python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0
```
BGS: accuracy 89.66% (3 runs, DGL), 83.10% (paper)
BGS: accuracy 93.10% (3 runs, DGL), 83.10% (paper)
```
python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
AM: accuracy 89.73% (3 runs, DGL), 89.29% (paper)
AM: accuracy 89.22% (3 runs, DGL), 89.29% (paper)
```
python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --testing
```
### Entity Classification with minibatch
AIFB: accuracy avg(5 runs) 90.56%, best 94.44% (DGL)
AIFB: accuracy avg(5 runs) 90.00%, best 94.44% (DGL)
```
python3 entity_classify_mp.py -d aifb --testing --gpu 0 --fanout='20,20' --batch-size 128
```
MUTAG: accuracy avg(10 runs) 69.41%, best 76.47% (DGL)
MUTAG: accuracy avg(10 runs) 62.94%, best 72.06% (DGL)
```
python3 entity_classify_mp.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --dgl-sparse --n-epochs 20 --sparse-lr 0.01 --dropout 0.5
```
BGS: accuracy avg(5 runs) 85.52%, best 93.10% (DGL)
BGS: accuracy avg(5 runs) 78.62%, best 86.21% (DGL)
```
python3 entity_classify_mp.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --dgl-sparse --lr 0.01 --sparse-lr 0.05 --dropout 0.3
```
AM: accuracy avg(5 runs) 88.59%, best 88.89% (DGL)
AM: accuracy avg(5 runs) 87.37%, best 89.9% (DGL)
```
python3 entity_classify_mp.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --dgl-sparse --lr 0.01 --sparse-lr 0.02 --dropout 0.7
```
......
......@@ -218,8 +218,9 @@ class RelGraphConv(nn.Module):
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
idim = weight.shape[1]
weight = weight.view(-1, weight.shape[2])
flatidx = etypes * weight.shape[1] + h
flatidx = etypes * idim + h
msg = weight.index_select(0, flatidx)
elif self.low_mem:
# A more memory-friendly implementation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment