Commit d5480cbe authored by Gustaf Ahdritz

Fix horrible in-place operation bug

parent 2faff451
@@ -23,7 +23,7 @@
         "opt_level": "O2"
     },
     "zero_optimization": {
-        "stage": 1
+        "stage": 2
     },
     "activation_checkpointing": {
         "partition_activations": true,
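Note on the config change: DeepSpeed's ZeRO stage 1 shards only optimizer states across data-parallel ranks, while stage 2 additionally shards gradients, cutting per-GPU memory further at the cost of some extra communication. A minimal sketch of the config fragment this hunk edits; only the keys visible in the diff are taken from the source, the surrounding structure is an assumption:

import json

ds_config = {
    "amp": {"enabled": True, "opt_level": "O2"},  # "amp" block name is an assumption
    "zero_optimization": {
        "stage": 2,  # stage 1: shard optimizer states; stage 2: also shard gradients
    },
    "activation_checkpointing": {
        "partition_activations": True,
    },
}

with open("deepspeed_config.json", "w") as f:
    json.dump(ds_config, f, indent=4)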
@@ -44,7 +44,6 @@ class TemplatePointwiseAttention(nn.Module):
     """
     Implements Algorithm 17.
     """
-
     def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
         """
         Args:
@@ -85,9 +84,6 @@ class TemplatePointwiseAttention(nn.Module):
             [*, N_res, N_res, C_z] pair embedding update
         """
         if template_mask is None:
-            # NOTE: This is not the "template_mask" from the supplement, but a
-            # [*, N_templ] mask from the code. I'm pretty sure it's always just
-            # 1, but not sure enough to remove it. It's nice to have, I guess.
             template_mask = t.new_ones(t.shape[:-3])

         bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
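Note on the bias line kept in this hunk: template_mask holds ones for real templates and zeros for padding, so self.inf * (mask - 1) is 0 for kept templates and approximately -inf for masked ones; added to the attention logits before the softmax, it drives attention on masked templates to zero. A minimal, self-contained sketch of the pattern (toy shapes, not OpenFold's code):

import torch

inf = 1e9                                   # stands in for self.inf
template_mask = torch.tensor([1., 1., 0.])  # [N_templ]; last template is padding
bias = inf * (template_mask - 1)            # 0 where kept, -1e9 where masked

logits = torch.randn(5, 3)                  # toy attention logits over 3 templates
weights = torch.softmax(logits + bias, dim=-1)
print(weights[:, -1])                       # ~0 attention on the masked template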
@@ -174,10 +170,15 @@ class TemplatePairStackBlock(nn.Module):
         )

     def forward(self, z, mask, chunk_size, _mask_trans=True):
-        for templ_idx in range(z.shape[-4]):
-            # Select a single template at a time
-            single = z[..., templ_idx:templ_idx+1, :, :, :]
-            single_mask = mask[..., templ_idx:templ_idx+1, :, :]
+        single_templates = [
+            t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
+        ]
+        single_templates_masks = [
+            m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
+        ]
+        for i in range(len(single_templates)):
+            single = single_templates[i]
+            single_mask = single_templates_masks[i]
             single = single + self.dropout_row(
                 self.tri_att_start(
@@ -210,7 +211,10 @@ class TemplatePairStackBlock(nn.Module):
                 chunk_size=chunk_size,
                 mask=single_mask if _mask_trans else None
             )
-            z[..., templ_idx:templ_idx+1, :, :, :] = single
+
+            single_templates[i] = single
+
+        z = torch.cat(single_templates, dim=-4)

         return z
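This hunk is the in-place bug named in the commit message: the old loop wrote each processed template back into z with slice assignment, mutating a tensor that autograd may still need for the backward pass. The new code unbinds z into per-template tensors, updates them out-of-place, and reassembles the stack with torch.cat. A minimal reproducer of the failure mode and the fix pattern (toy ops, not OpenFold's actual modules):

import torch

# Failing pattern: in-place write into a tensor autograd saved for backward.
z = torch.randn(4, requires_grad=True)
out = torch.sigmoid(z)      # sigmoid saves its output for the backward pass
out[1] = 0.0                # in-place slice assignment, as in the old loop
try:
    out.sum().backward()
except RuntimeError as err:
    print(err)              # "... modified by an inplace operation"

# Fixed pattern, mirroring the commit: unbind, update out-of-place, re-cat.
z = torch.randn(2, 4, requires_grad=True)
singles = [t.unsqueeze(0) for t in torch.unbind(torch.sigmoid(z), dim=0)]
singles = [s * 2.0 for s in singles]  # stand-in for the per-template update
result = torch.cat(singles, dim=0)
result.sum().backward()               # backward now succeeds
print(z.grad.shape)                   # torch.Size([2, 4])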
@@ -219,7 +223,6 @@ class TemplatePairStack(nn.Module):
     """
     Implements Algorithm 16.
     """
-
     def __init__(
         self,
         c_t,
@@ -61,6 +61,10 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)

+        if(torch.isnan(loss) or torch.isinf(loss)):
+            loss = None
+            logging.warning("loss is NaN. Skipping example...")
+
         return {"loss": loss}

     def validation_step(self, batch, batch_idx):
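Note on the NaN guard added above: under PyTorch Lightning's automatic optimization, a training_step that returns None skips the batch, which is presumably what setting loss = None relies on here (whether a {"loss": None} dict is treated identically may depend on the Lightning version). Also, the check catches Inf losses too, though the warning only mentions NaN. A standalone sketch of the same guard (hypothetical helper, not the wrapper's code):

import logging
import torch

def guard_loss(loss):
    # Map non-finite losses to None so Lightning skips the batch.
    if torch.isnan(loss) or torch.isinf(loss):
        logging.warning("loss is NaN. Skipping example...")
        return None
    return {"loss": loss}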
@@ -113,12 +117,12 @@ def main(args):
        sd = {k[len("module."):]:v for k,v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")
-    #data_module = DummyDataLoader("batch.pickle")
-    data_module = OpenFoldDataModule(
-        config=config.data,
-        batch_seed=args.seed,
-        **vars(args)
-    )
+    data_module = DummyDataLoader("batch.pickle")
+    #data_module = OpenFoldDataModule(
+    #    config=config.data,
+    #    batch_seed=args.seed,
+    #    **vars(args)
+    #)

    data_module.prepare_data()
    data_module.setup()
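The last hunk swaps the real OpenFoldDataModule out for DummyDataLoader("batch.pickle"), which replays a single pre-pickled batch; that isolates a training-loop bug like the in-place one above from the data pipeline. A hypothetical sketch of what such a module might look like (OpenFold's actual DummyDataLoader may differ):

import pickle
import pytorch_lightning as pl
import torch

class DummyDataLoader(pl.LightningDataModule):
    # Hypothetical sketch: replays one pickled, pre-collated batch.
    def __init__(self, path):
        super().__init__()
        with open(path, "rb") as f:
            self.batch = pickle.load(f)

    def prepare_data(self):
        pass  # nothing to download or preprocess

    def setup(self, stage=None):
        pass  # batch already loaded in __init__

    def train_dataloader(self):
        # One-element dataset; collate_fn unwraps the list to the stored batch.
        return torch.utils.data.DataLoader(
            [self.batch], batch_size=1, collate_fn=lambda samples: samples[0]
        )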