Unverified Commit 15d05be3 authored by Zhen Liu's avatar Zhen Liu Committed by GitHub
Browse files

Fix num_labels to num_classes in dataset files (#6666)

parent 5e64481b
...@@ -93,7 +93,7 @@ def main(args): ...@@ -93,7 +93,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -69,7 +69,7 @@ def main(args): ...@@ -69,7 +69,7 @@ def main(args):
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
test_mask = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], ctx=ctx) test_mask = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], ctx=ctx)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
g = dgl.remove_self_loop(g) g = dgl.remove_self_loop(g)
......
...@@ -94,7 +94,7 @@ def main(args): ...@@ -94,7 +94,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -46,7 +46,7 @@ def main(args): ...@@ -46,7 +46,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -111,7 +111,7 @@ def main(args): ...@@ -111,7 +111,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -93,7 +93,7 @@ def main(args): ...@@ -93,7 +93,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -56,7 +56,7 @@ def main(args): ...@@ -56,7 +56,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -48,7 +48,7 @@ def main(args): ...@@ -48,7 +48,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -86,7 +86,7 @@ class QM9(QM9Dataset): ...@@ -86,7 +86,7 @@ class QM9(QM9Dataset):
Examples Examples
-------- --------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0) >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
>>> data.num_labels >>> data.num_classes
2 2
>>> >>>
>>> # iterate over the dataset >>> # iterate over the dataset
......
...@@ -116,7 +116,7 @@ if __name__ == "__main__": ...@@ -116,7 +116,7 @@ if __name__ == "__main__":
# create GAT model # create GAT model
in_size = features.shape[1] in_size = features.shape[1]
out_size = train_dataset.num_labels out_size = train_dataset.num_classes
model = GAT(in_size, 256, out_size, heads=[4, 4, 6]).to(device) model = GAT(in_size, 256, out_size, heads=[4, 4, 6]).to(device)
# model training # model training
......
...@@ -49,7 +49,7 @@ def main(args): ...@@ -49,7 +49,7 @@ def main(args):
else: else:
device = "cpu" device = "cpu"
num_classes = train_dataset.num_labels num_classes = train_dataset.num_classes
# Extract node features # Extract node features
graph = train_dataset[0] graph = train_dataset[0]
......
...@@ -59,7 +59,7 @@ def main(args): ...@@ -59,7 +59,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.num_edges() n_edges = g.num_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -51,7 +51,7 @@ def main(args): ...@@ -51,7 +51,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.num_edges() n_edges = g.num_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -28,7 +28,7 @@ def load_dataset(name): ...@@ -28,7 +28,7 @@ def load_dataset(name):
data = CitationGraphDataset("cora") data = CitationGraphDataset("cora")
g = data[0] g = data[0]
n_classes = data.num_labels n_classes = data.num_classes
train_mask = g.ndata["train_mask"] train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
......
...@@ -38,7 +38,7 @@ def main(args): ...@@ -38,7 +38,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.num_edges() n_edges = g.num_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -50,7 +50,7 @@ def main(args): ...@@ -50,7 +50,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
# add self loop # add self loop
......
...@@ -66,7 +66,7 @@ def main(args): ...@@ -66,7 +66,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
num_feats = features.shape[1] num_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -114,7 +114,7 @@ def main(args): ...@@ -114,7 +114,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -121,7 +121,7 @@ def main(args): ...@@ -121,7 +121,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
...@@ -43,7 +43,7 @@ def main(args): ...@@ -43,7 +43,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.number_of_edges() n_edges = g.number_of_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment