Commit 98c1448b authored by Jacob Stevens's avatar Jacob Stevens Committed by Minjie Wang
Browse files

[Cleanup] Change Byte to Bool for training masks (#954)

* Change Byte to Bool for training masks

* Check if module has Bool, otherwise use Byte
parent 86cf154b
...@@ -25,9 +25,14 @@ def main(args): ...@@ -25,9 +25,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -46,9 +46,14 @@ def main(args): ...@@ -46,9 +46,14 @@ def main(args):
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
else: else:
labels = torch.FloatTensor(data.labels) labels = torch.FloatTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask).type(torch.bool) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask).type(torch.bool) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask).type(torch.bool) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -22,9 +22,14 @@ def main(args): ...@@ -22,9 +22,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -41,9 +41,14 @@ def main(args): ...@@ -41,9 +41,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
num_feats = features.shape[1] num_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -118,9 +118,14 @@ def main(args): ...@@ -118,9 +118,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -25,9 +25,14 @@ def main(args): ...@@ -25,9 +25,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -60,9 +60,14 @@ def main(args): ...@@ -60,9 +60,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -45,9 +45,14 @@ def main(args): ...@@ -45,9 +45,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
...@@ -147,4 +152,4 @@ if __name__ == '__main__': ...@@ -147,4 +152,4 @@ if __name__ == '__main__':
help="graph self-loop (default=False)") help="graph self-loop (default=False)")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
\ No newline at end of file
...@@ -24,9 +24,14 @@ def main(args): ...@@ -24,9 +24,14 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -25,9 +25,14 @@ def main(args): ...@@ -25,9 +25,14 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -149,9 +149,14 @@ def main(args): ...@@ -149,9 +149,14 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -120,9 +120,14 @@ def main(args): ...@@ -120,9 +120,14 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -30,9 +30,14 @@ def main(args): ...@@ -30,9 +30,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -33,9 +33,14 @@ def main(args): ...@@ -33,9 +33,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -23,9 +23,14 @@ def main(args): ...@@ -23,9 +23,14 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment