Unverified Commit 4c883d89 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[Misc] Replace api to get feature size in `node_classfication` and `hetero_rgcn`. (#6460)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 3774945d
...@@ -133,12 +133,11 @@ class SAGE(LightningModule): ...@@ -133,12 +133,11 @@ class SAGE(LightningModule):
class DataModule(LightningDataModule): class DataModule(LightningDataModule):
def __init__(self, fanouts, batch_size, num_workers): def __init__(self, dataset, fanouts, batch_size, num_workers):
super().__init__() super().__init__()
self.fanouts = fanouts self.fanouts = fanouts
self.batch_size = batch_size self.batch_size = batch_size
self.num_workers = num_workers self.num_workers = num_workers
dataset = gb.BuiltinDataset("ogbn-products").load()
self.feature_store = dataset.feature self.feature_store = dataset.feature
self.graph = dataset.graph self.graph = dataset.graph
self.train_set = dataset.tasks[0].train_set self.train_set = dataset.tasks[0].train_set
...@@ -209,8 +208,15 @@ if __name__ == "__main__": ...@@ -209,8 +208,15 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
datamodule = DataModule([15, 10, 5], args.batch_size, args.num_workers) dataset = gb.BuiltinDataset("ogbn-products").load()
model = SAGE(100, 256, datamodule.num_classes).to(torch.double) datamodule = DataModule(
dataset,
[15, 10, 5],
args.batch_size,
args.num_workers,
)
in_size = dataset.feature.size("node", None, "feat")[0]
model = SAGE(in_size, 256, datamodule.num_classes).to(torch.double)
# Train. # Train.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max") checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
......
...@@ -631,11 +631,7 @@ def main(args): ...@@ -631,11 +631,7 @@ def main(args):
args.dataset args.dataset
) )
# TODO: featch from ``feature store``. feat_size = features.size("node", None, "feat")[0]
if args.dataset == "ogbn-mag":
feat_size = 128
else:
feat_size = 768
# As `ogb-lsc-mag240m` is a large dataset, features of `author` and # As `ogb-lsc-mag240m` is a large dataset, features of `author` and
# `institution` are generated in advance and stored in the feature store. # `institution` are generated in advance and stored in the feature store.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment