Unverified Commit 4c883d89 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[Misc] Replace api to get feature size in `node_classfication` and `hetero_rgcn`. (#6460)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 3774945d
......@@ -133,12 +133,11 @@ class SAGE(LightningModule):
class DataModule(LightningDataModule):
def __init__(self, fanouts, batch_size, num_workers):
def __init__(self, dataset, fanouts, batch_size, num_workers):
super().__init__()
self.fanouts = fanouts
self.batch_size = batch_size
self.num_workers = num_workers
dataset = gb.BuiltinDataset("ogbn-products").load()
self.feature_store = dataset.feature
self.graph = dataset.graph
self.train_set = dataset.tasks[0].train_set
......@@ -209,8 +208,15 @@ if __name__ == "__main__":
)
args = parser.parse_args()
datamodule = DataModule([15, 10, 5], args.batch_size, args.num_workers)
model = SAGE(100, 256, datamodule.num_classes).to(torch.double)
dataset = gb.BuiltinDataset("ogbn-products").load()
datamodule = DataModule(
dataset,
[15, 10, 5],
args.batch_size,
args.num_workers,
)
in_size = dataset.feature.size("node", None, "feat")[0]
model = SAGE(in_size, 256, datamodule.num_classes).to(torch.double)
# Train.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
......
......@@ -631,11 +631,7 @@ def main(args):
args.dataset
)
# TODO: featch from ``feature store``.
if args.dataset == "ogbn-mag":
feat_size = 128
else:
feat_size = 768
feat_size = features.size("node", None, "feat")[0]
# As `ogb-lsc-mag240m` is a large dataset, features of `author` and
# `institution` are generated in advance and stored in the feature store.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment