Unverified Commit 04ed6126 authored by Christian Clauss's avatar Christian Clauss Committed by GitHub
Browse files

[Fix] Use ==/!= to compare constant literals (str, bytes, int, float, tuple) (#3415)

* Use ==/!= to compare constant literals (str, bytes, int, float, tuple)

Avoid Syntax Warnings on Python >= 3.8

$ `python3`
```
>>> "" == ""
True
>>> "" is ""
<stdin>:1: SyntaxWarning: "is" with a literal. Did you mean "=="?
True
```

* Use ==/!= to compare constant literals (str, bytes, int, float, tuple)
parent b81efb2b
...@@ -232,12 +232,12 @@ def compute_perm(parents): ...@@ -232,12 +232,12 @@ def compute_perm(parents):
assert 0 <= len(indices_node) <= 2 assert 0 <= len(indices_node) <= 2
# Add a node to go with a singelton. # Add a node to go with a singelton.
if len(indices_node) is 1: if len(indices_node) == 1:
indices_node.append(pool_singeltons) indices_node.append(pool_singeltons)
pool_singeltons += 1 pool_singeltons += 1
# Add two nodes as children of a singelton in the parent. # Add two nodes as children of a singelton in the parent.
elif len(indices_node) is 0: elif len(indices_node) == 0:
indices_node.append(pool_singeltons + 0) indices_node.append(pool_singeltons + 0)
indices_node.append(pool_singeltons + 1) indices_node.append(pool_singeltons + 1)
pool_singeltons += 2 pool_singeltons += 2
......
...@@ -158,11 +158,11 @@ def main(): ...@@ -158,11 +158,11 @@ def main():
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, collate_fn=collate_dgl) num_workers=args.num_workers, collate_fn=collate_dgl)
if args.save_test_dir is not '': if args.save_test_dir != '':
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, collate_fn=collate_dgl) num_workers=args.num_workers, collate_fn=collate_dgl)
if args.checkpoint_dir is not '': if args.checkpoint_dir != '':
os.makedirs(args.checkpoint_dir, exist_ok=True) os.makedirs(args.checkpoint_dir, exist_ok=True)
shared_params = { shared_params = {
...@@ -188,7 +188,7 @@ def main(): ...@@ -188,7 +188,7 @@ def main():
optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.Adam(model.parameters(), lr=0.001)
if args.log_dir is not '': if args.log_dir != '':
writer = SummaryWriter(log_dir=args.log_dir) writer = SummaryWriter(log_dir=args.log_dir)
best_valid_mae = 1000 best_valid_mae = 1000
...@@ -209,13 +209,13 @@ def main(): ...@@ -209,13 +209,13 @@ def main():
print({'Train': train_mae, 'Validation': valid_mae}) print({'Train': train_mae, 'Validation': valid_mae})
if args.log_dir is not '': if args.log_dir != '':
writer.add_scalar('valid/mae', valid_mae, epoch) writer.add_scalar('valid/mae', valid_mae, epoch)
writer.add_scalar('train/mae', train_mae, epoch) writer.add_scalar('train/mae', train_mae, epoch)
if valid_mae < best_valid_mae: if valid_mae < best_valid_mae:
best_valid_mae = valid_mae best_valid_mae = valid_mae
if args.checkpoint_dir is not '': if args.checkpoint_dir != '':
print('Saving checkpoint...') print('Saving checkpoint...')
checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
...@@ -223,7 +223,7 @@ def main(): ...@@ -223,7 +223,7 @@ def main():
'num_params': num_params} 'num_params': num_params}
torch.save(checkpoint, os.path.join(args.checkpoint_dir, 'checkpoint.pt')) torch.save(checkpoint, os.path.join(args.checkpoint_dir, 'checkpoint.pt'))
if args.save_test_dir is not '': if args.save_test_dir != '':
print('Predicting on test data...') print('Predicting on test data...')
y_pred = test(model, device, test_loader) y_pred = test(model, device, test_loader)
print('Saving test submission file...') print('Saving test submission file...')
...@@ -233,7 +233,7 @@ def main(): ...@@ -233,7 +233,7 @@ def main():
print(f'Best validation MAE so far: {best_valid_mae}') print(f'Best validation MAE so far: {best_valid_mae}')
if args.log_dir is not '': if args.log_dir != '':
writer.close() writer.close()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment