Commit 98c1448b authored by Jacob Stevens's avatar Jacob Stevens Committed by Minjie Wang
Browse files

[Cleanup] Change Byte to Bool for training masks (#954)

* Change Byte to Bool for training masks

* Check if module has Bool, otherwise use Byte
parent 86cf154b
...@@ -25,6 +25,11 @@ def main(args): ...@@ -25,6 +25,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -46,9 +46,14 @@ def main(args): ...@@ -46,9 +46,14 @@ def main(args):
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
else: else:
labels = torch.FloatTensor(data.labels) labels = torch.FloatTensor(data.labels)
train_mask = torch.ByteTensor(data.train_mask).type(torch.bool) if hasattr(torch, 'BoolTensor'):
val_mask = torch.ByteTensor(data.val_mask).type(torch.bool) train_mask = torch.BoolTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask).type(torch.bool) val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask)
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
......
...@@ -22,6 +22,11 @@ def main(args): ...@@ -22,6 +22,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -41,6 +41,11 @@ def main(args): ...@@ -41,6 +41,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -118,6 +118,11 @@ def main(args): ...@@ -118,6 +118,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -25,6 +25,11 @@ def main(args): ...@@ -25,6 +25,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -60,6 +60,11 @@ def main(args): ...@@ -60,6 +60,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -45,6 +45,11 @@ def main(args): ...@@ -45,6 +45,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -24,6 +24,11 @@ def main(args): ...@@ -24,6 +24,11 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -25,6 +25,11 @@ def main(args): ...@@ -25,6 +25,11 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -149,6 +149,11 @@ def main(args): ...@@ -149,6 +149,11 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -120,6 +120,11 @@ def main(args): ...@@ -120,6 +120,11 @@ def main(args):
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -30,6 +30,11 @@ def main(args): ...@@ -30,6 +30,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -33,6 +33,11 @@ def main(args): ...@@ -33,6 +33,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
...@@ -23,6 +23,11 @@ def main(args): ...@@ -23,6 +23,11 @@ def main(args):
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
if hasattr(torch, 'BoolTensor'):
train_mask = torch.BoolTensor(data.train_mask)
val_mask = torch.BoolTensor(data.val_mask)
test_mask = torch.BoolTensor(data.test_mask)
else:
train_mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
test_mask = torch.ByteTensor(data.test_mask) test_mask = torch.ByteTensor(data.test_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment