Commit b066dd55 authored by Gustaf Ahdritz

Fix no-DeepSpeed DDP

parent 8ffec72a
@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int)
 blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
 chunk_size = mlc.FieldReference(4, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
+tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
 templates_enabled = mlc.FieldReference(True, field_type=bool)
 embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
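
The new tm_enabled reference follows the pattern already used by templates_enabled above: a single mlc.FieldReference threaded through several config entries (heads.tm.enabled and loss.tm.enabled further down), so one assignment toggles all of them. A minimal sketch of the mechanism, with illustrative key names rather than the repo's real ones:

    import ml_collections as mlc

    # One shared reference object placed in two spots of the config.
    tm_enabled = mlc.FieldReference(False, field_type=bool)
    cfg = mlc.ConfigDict({
        "heads": {"tm": {"enabled": tm_enabled}},
        "loss": {"tm": {"enabled": tm_enabled}},
    })

    cfg.heads.tm.enabled = True  # assignment updates the underlying reference
    assert cfg.loss.tm.enabled   # ...so the loss config sees the change too
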
@@ -148,7 +149,7 @@ config = mlc.ConfigDict(
                     "same_prob": 0.1,
                     "uniform_prob": 0.1,
                 },
-                "max_extra_msa": 5120,
+                "max_extra_msa": 1024,
                 "max_recycling_iters": 3,
                 "msa_cluster_features": True,
                 "reduce_msa_clusters_by_max_templates": False,
@@ -174,7 +175,6 @@ config = mlc.ConfigDict(
             },
             "supervised": {
                 "clamp_prob": 0.9,
-                "uniform_recycling": True,
                 "supervised_features": [
                     "all_atom_mask",
                     "all_atom_positions",
@@ -194,6 +194,7 @@ config = mlc.ConfigDict(
                 "crop_size": None,
                 "supervised": False,
                 "subsample_recycling": False,
+                "uniform_recycling": False,
             },
             "eval": {
                 "fixed_size": True,
@@ -206,27 +207,29 @@ config = mlc.ConfigDict(
                 "crop_size": None,
                 "supervised": True,
                 "subsample_recycling": False,
+                "uniform_recycling": False,
             },
             "train": {
                 "fixed_size": True,
                 "subsample_templates": True,
                 "masked_msa_replace_fraction": 0.15,
-                "max_msa_clusters": 512,
+                "max_msa_clusters": 128,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "shuffle_top_k_prefiltered": 20,
                 "crop": True,
-                "crop_size": 384,
+                "crop_size": 256,
                 "supervised": True,
                 "clamp_prob": 0.9,
                 "subsample_recycling": True,
                 "max_distillation_msa_clusters": 1000,
+                "uniform_recycling": True,
             },
             "data_module": {
                 "use_small_bfd": False,
                 "data_loaders": {
                     "batch_size": 1,
-                    "num_workers": 2,
+                    "num_workers": 4,
                 },
             },
         },
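
uniform_recycling moves out of the shared supervised block and becomes a per-mode setting: off for predict and eval, on for train. The diff doesn't define the flag, but in AlphaFold-style training it conventionally means drawing the number of recycling iterations uniformly per batch instead of always running the maximum; a rough sketch of that convention (my reading, not taken from the repo):

    import random

    max_recycling_iters = 3  # matches "max_recycling_iters" above

    def num_recycles(uniform_recycling: bool) -> int:
        if uniform_recycling:
            # train: sample a depth in [0, max] so the model sees all depths
            return random.randint(0, max_recycling_iters)
        # eval/predict: always run the full number of iterations
        return max_recycling_iters
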
@@ -374,7 +377,7 @@ config = mlc.ConfigDict(
             "tm": {
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
-                "enabled": False,
+                "enabled": tm_enabled,
             },
             "masked_msa": {
                 "c_m": c_m,
@@ -452,6 +455,7 @@ config = mlc.ConfigDict(
                 "max_resolution": 3.0,
                 "eps": eps,  # 1e-8,
                 "weight": 0.0,
+                "enabled": tm_enabled,
             },
             "eps": eps,
         },
@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels):
 def sigmoid_cross_entropy(logits, labels):
-    log_p = torch.nn.functional.logsigmoid(logits)
-    log_not_p = torch.nn.functional.logsigmoid(-logits)
+    log_p = torch.log(torch.sigmoid(logits))
+    log_not_p = torch.log(torch.sigmoid(-logits))
     loss = -labels * log_p - (1 - labels) * log_not_p
     return loss
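
One property of the rewritten sigmoid_cross_entropy worth flagging (the commit message gives no motivation for this hunk): the two forms are algebraically identical, but torch.log(torch.sigmoid(x)) underflows to -inf for large-magnitude negative x, while torch.nn.functional.logsigmoid is evaluated in a numerically stable way. A quick check:

    import torch

    x = torch.tensor([-200.0])
    print(torch.nn.functional.logsigmoid(x))  # tensor([-200.]): stable
    print(torch.log(torch.sigmoid(x)))        # tensor([-inf]): sigmoid underflowed to 0
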
@@ -1462,7 +1462,7 @@ class AlphaFoldLoss(nn.Module):
         self.config = config

     def forward(self, out, batch):
-        if "violation" not in out.keys() and self.config.violation.weight:
+        if "violation" not in out.keys():
             out["violation"] = find_structural_violations(
                 batch,
                 out["sm"]["positions"][-1],
@@ -1509,20 +1509,21 @@ class AlphaFoldLoss(nn.Module):
                 out["violation"],
                 **batch,
             ),
-            "tm": lambda: tm_loss(
+        }
+
+        if(self.config.tm.enabled):
+            loss_fns["tm"] = lambda: tm_loss(
                 logits=out["tm_logits"],
                 **{**batch, **out, **self.config.tm},
-            ),
-        }
+            )

         cum_loss = 0.
         for loss_name, loss_fn in loss_fns.items():
             weight = self.config[loss_name].weight
-            if weight:
-                loss = loss_fn()
-                if(torch.isnan(loss) or torch.isinf(loss)):
-                    logging.warning(f"{loss_name} loss is NaN. Skipping...")
-                    loss = loss.new_tensor(0., requires_grad=True)
-                cum_loss = cum_loss + weight * loss
+            loss = loss_fn()
+            if(torch.isnan(loss) or torch.isinf(loss)):
+                logging.warning(f"{loss_name} loss is NaN. Skipping...")
+                loss = loss.new_tensor(0., requires_grad=True)
+            cum_loss = cum_loss + weight * loss

         return cum_loss
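
Read together with the commit title, the forward() changes have a consistent effect (my inference; the diff itself doesn't state it): find_structural_violations now runs regardless of the violation weight, and every registered loss term is evaluated and scaled by its weight, even when that weight is 0.0. Plain torch.nn.parallel.DistributedDataParallel, unlike DeepSpeed, raises unused-parameter errors when some head never contributes to the loss unless find_unused_parameters=True is set; multiplying by a zero weight keeps each enabled head in the autograd graph. A self-contained sketch of that behavior:

    import torch

    # Hypothetical two-head model; "b" stands in for an optional head
    # such as tm whose loss weight may be 0.0.
    class TwoHeads(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.nn.Linear(8, 1)
            self.b = torch.nn.Linear(8, 1)

        def forward(self, x):
            return self.a(x), self.b(x)

    model = TwoHeads()
    out_a, out_b = model(torch.randn(4, 8))

    # Skipping zero-weight terms would leave b's parameters without
    # gradients, which DDP's reducer treats as unused parameters.
    # Evaluating every term and scaling by its (possibly zero) weight
    # keeps b in the graph:
    loss = 1.0 * out_a.mean() + 0.0 * out_b.mean()
    loss.backward()
    print(model.b.weight.grad is not None)  # True: b still participates
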