"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "69cdc25746d880279cb79b2018c7de04b8ecf89f"
Unverified Commit 26b63180 authored by skepsun's avatar skepsun Committed by GitHub
Browse files

[Bugfix] Fix Correct&Smooth (#3329)



* Update model.py

fix typo

* Update main.py

fix autoscale

* Update README.md
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 246df1a8
...@@ -32,14 +32,14 @@ Training a **Base predictor** and using **Correct&Smooth** which follows the ori ...@@ -32,14 +32,14 @@ Training a **Base predictor** and using **Correct&Smooth** which follows the ori
```bash ```bash
python main.py --dropout 0.5 python main.py --dropout 0.5
python main.py --pretrain --correction-adj DA --smoothing-adj AD python main.py --pretrain --correction-adj DA --smoothing-adj AD --autoscale
``` ```
* **Linear + C&S** * **Linear + C&S**
```bash ```bash
python main.py --model linear --dropout 0.5 --epochs 1000 python main.py --model linear --dropout 0.5 --epochs 1000
python main.py --model linear --pretrain --correction-alpha 0.8 --smoothing-alpha 0.6 --correction-adj AD python main.py --model linear --pretrain --correction-alpha 0.8 --smoothing-alpha 0.6 --correction-adj AD --autoscale
``` ```
##### ogbn-products ##### ogbn-products
...@@ -48,7 +48,7 @@ python main.py --model linear --pretrain --correction-alpha 0.8 --smoothing-alph ...@@ -48,7 +48,7 @@ python main.py --model linear --pretrain --correction-alpha 0.8 --smoothing-alph
```bash ```bash
python main.py --dataset ogbn-products --model linear --dropout 0.5 --epochs 1000 --lr 0.1 python main.py --dataset ogbn-products --model linear --dropout 0.5 --epochs 1000 --lr 0.1
python main.py --dataset ogbn-products --model linear --pretrain --correction-alpha 0.6 --smoothing-alpha 0.9 python main.py --dataset ogbn-products --model linear --pretrain --correction-alpha 1. --smoothing-alpha 0.9
``` ```
### Performance ### Performance
...@@ -58,14 +58,14 @@ python main.py --dataset ogbn-products --model linear --pretrain --correction-al ...@@ -58,14 +58,14 @@ python main.py --dataset ogbn-products --model linear --pretrain --correction-al
| | MLP | MLP + C&S | Linear | Linear + C&S | | | MLP | MLP + C&S | Linear | Linear + C&S |
| :-------------: | :---: | :-------: | :----: | :----------: | | :-------------: | :---: | :-------: | :----: | :----------: |
| Results(Author) | 55.58 | 68.72 | 51.06 | 70.24 | | Results(Author) | 55.58 | 68.72 | 51.06 | 70.24 |
| Results(DGL) | 56.12 | 68.63 | 52.49 | 71.69 | | Results(DGL) | 56.55 | 70.93 | 52.48 | 72.60 |
#### ogbn-products #### ogbn-products
| | Linear | Linear + C&S | | | Linear | Linear + C&S |
| :-------------: | :----: | :----------: | | :-------------: | :----: | :----------: |
| Results(Author) | 47.67 | 82.34 | | Results(Author) | 47.67 | 82.34 |
| Results(DGL) | 47.71 | 79.57 | | Results(DGL) | 47.65 | 82.86 |
### Speed ### Speed
......
...@@ -75,6 +75,7 @@ def main(): ...@@ -75,6 +75,7 @@ def main():
num_smoothing_layers=args.num_smoothing_layers, num_smoothing_layers=args.num_smoothing_layers,
smoothing_alpha=args.smoothing_alpha, smoothing_alpha=args.smoothing_alpha,
smoothing_adj=args.smoothing_adj, smoothing_adj=args.smoothing_adj,
autoscale=args.autoscale,
scale=args.scale) scale=args.scale)
mask_idx = torch.cat([train_idx, valid_idx]) mask_idx = torch.cat([train_idx, valid_idx])
...@@ -162,6 +163,7 @@ if __name__ == '__main__': ...@@ -162,6 +163,7 @@ if __name__ == '__main__':
parser.add_argument('--num-smoothing-layers', type=int, default=50) parser.add_argument('--num-smoothing-layers', type=int, default=50)
parser.add_argument('--smoothing-alpha', type=float, default=0.756) parser.add_argument('--smoothing-alpha', type=float, default=0.756)
parser.add_argument('--smoothing-adj', type=str, default='DAD') parser.add_argument('--smoothing-adj', type=str, default='DAD')
parser.add_argument('--autoscale', action='store_true')
parser.add_argument('--scale', type=float, default=20.) parser.add_argument('--scale', type=float, default=20.)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -163,7 +163,7 @@ class CorrectAndSmooth(nn.Module): ...@@ -163,7 +163,7 @@ class CorrectAndSmooth(nn.Module):
correction_adj) correction_adj)
self.prop2 = LabelPropagation(num_smoothing_layers, self.prop2 = LabelPropagation(num_smoothing_layers,
smoothing_alpha, smoothing_alpha,
correction_adj) smoothing_adj)
def correct(self, g, y_soft, y_true, mask): def correct(self, g, y_soft, y_true, mask):
with g.local_scope(): with g.local_scope():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment