Commit d5480cbe authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix horrible in-place operation bug

parent 2faff451
......@@ -23,7 +23,7 @@
"opt_level": "O2"
},
"zero_optimization": {
"stage": 1
"stage": 2
},
"activation_checkpointing": {
"partition_activations": true,
......
......@@ -44,7 +44,6 @@ class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
......@@ -85,9 +84,6 @@ class TemplatePointwiseAttention(nn.Module):
[*, N_res, N_res, C_z] pair embedding update
"""
if template_mask is None:
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just
# 1, but not sure enough to remove it. It's nice to have, I guess.
template_mask = t.new_ones(t.shape[:-3])
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
......@@ -174,11 +170,16 @@ class TemplatePairStackBlock(nn.Module):
)
def forward(self, z, mask, chunk_size, _mask_trans=True):
for templ_idx in range(z.shape[-4]):
# Select a single template at a time
single = z[..., templ_idx:templ_idx+1, :, :, :]
single_mask = mask[..., templ_idx:templ_idx+1, :, :]
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
single = single + self.dropout_row(
self.tri_att_start(
single,
......@@ -210,7 +211,10 @@ class TemplatePairStackBlock(nn.Module):
chunk_size=chunk_size,
mask=single_mask if _mask_trans else None
)
z[..., templ_idx:templ_idx+1, :, :, :] = single
single_templates[i] = single
z = torch.cat(single_templates, dim=-4)
return z
......@@ -219,7 +223,6 @@ class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(
self,
c_t,
......
......@@ -61,6 +61,10 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss = self.loss(outputs, batch)
if(torch.isnan(loss) or torch.isinf(loss)):
loss = None
logging.warning("loss is NaN. Skipping example...")
return {"loss": loss}
def validation_step(self, batch, batch_idx):
......@@ -113,12 +117,12 @@ def main(args):
sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd)
logging.info("Successfully loaded model weights...")
#data_module = DummyDataLoader("batch.pickle")
data_module = OpenFoldDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
data_module = DummyDataLoader("batch.pickle")
#data_module = OpenFoldDataModule(
# config=config.data,
# batch_seed=args.seed,
# **vars(args)
#)
data_module.prepare_data()
data_module.setup()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment