Commit b47138dc authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix TorchScripting bug, add command-line parameter for scripting

parent e1142cf3
......@@ -218,6 +218,7 @@ class Attention(nn.Module):
self.c_hidden * self.no_heads, self.c_q, init="final"
)
self.linear_g = None
if self.gating:
self.linear_g = Linear(
self.c_q, self.c_hidden * self.no_heads, init="gating"
......@@ -282,7 +283,7 @@ class Attention(nn.Module):
# [*, Q, H, C_hidden]
o = o.transpose(-2, -3)
if self.gating:
if(self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
......
......@@ -2,7 +2,7 @@ import argparse
import logging
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "5"
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
......@@ -13,7 +13,7 @@ import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch
......@@ -120,7 +120,8 @@ def main(args):
logging.info("Successfully loaded model weights...")
# TorchScript components of the model
script_preset_(model_module)
if(args.script_modules):
script_preset_(model_module)
#data_module = DummyDataLoader("batch.pickle")
data_module = OpenFoldDataModule(
......@@ -128,10 +129,10 @@ def main(args):
batch_seed=args.seed,
**vars(args)
)
data_module.prepare_data()
data_module.setup()
callbacks = []
if(args.checkpoint_best_val):
checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
......@@ -172,7 +173,7 @@ def main(args):
cluster_environment=cluster_environment,
)
elif (args.gpus is not None and args.gpus) > 1 or args.num_nodes > 1:
strategy = "ddp"
strategy = DDPPlugin(find_unused_parameters=False)
else:
strategy = None
......@@ -306,9 +307,13 @@ if __name__ == "__main__":
help="Whether to load just model weights as opposed to training state"
)
parser.add_argument(
"--log_performance", action='store_true',
"--log_performance", type=bool_type, default=False,
help="Measure performance"
)
parser.add_argument(
"--script_modules", type=bool_type, default=False,
help="Whether to TorchScript eligible components of them model"
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment