Commit b066dd55 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix no-DeepSpeed DDP

parent 8ffec72a
...@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int) ...@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int) blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int) chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int) aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float) eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool) templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool) embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
...@@ -148,7 +149,7 @@ config = mlc.ConfigDict( ...@@ -148,7 +149,7 @@ config = mlc.ConfigDict(
"same_prob": 0.1, "same_prob": 0.1,
"uniform_prob": 0.1, "uniform_prob": 0.1,
}, },
"max_extra_msa": 5120, "max_extra_msa": 1024,
"max_recycling_iters": 3, "max_recycling_iters": 3,
"msa_cluster_features": True, "msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False, "reduce_msa_clusters_by_max_templates": False,
...@@ -174,7 +175,6 @@ config = mlc.ConfigDict( ...@@ -174,7 +175,6 @@ config = mlc.ConfigDict(
}, },
"supervised": { "supervised": {
"clamp_prob": 0.9, "clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [ "supervised_features": [
"all_atom_mask", "all_atom_mask",
"all_atom_positions", "all_atom_positions",
...@@ -194,6 +194,7 @@ config = mlc.ConfigDict( ...@@ -194,6 +194,7 @@ config = mlc.ConfigDict(
"crop_size": None, "crop_size": None,
"supervised": False, "supervised": False,
"subsample_recycling": False, "subsample_recycling": False,
"uniform_recycling": False,
}, },
"eval": { "eval": {
"fixed_size": True, "fixed_size": True,
...@@ -206,27 +207,29 @@ config = mlc.ConfigDict( ...@@ -206,27 +207,29 @@ config = mlc.ConfigDict(
"crop_size": None, "crop_size": None,
"supervised": True, "supervised": True,
"subsample_recycling": False, "subsample_recycling": False,
"uniform_recycling": False,
}, },
"train": { "train": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": True, "subsample_templates": True,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512, "max_msa_clusters": 128,
"max_template_hits": 4, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"shuffle_top_k_prefiltered": 20, "shuffle_top_k_prefiltered": 20,
"crop": True, "crop": True,
"crop_size": 384, "crop_size": 256,
"supervised": True, "supervised": True,
"clamp_prob": 0.9, "clamp_prob": 0.9,
"subsample_recycling": True, "subsample_recycling": True,
"max_distillation_msa_clusters": 1000, "max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
}, },
"data_module": { "data_module": {
"use_small_bfd": False, "use_small_bfd": False,
"data_loaders": { "data_loaders": {
"batch_size": 1, "batch_size": 1,
"num_workers": 2, "num_workers": 4,
}, },
}, },
}, },
...@@ -374,7 +377,7 @@ config = mlc.ConfigDict( ...@@ -374,7 +377,7 @@ config = mlc.ConfigDict(
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": False, "enabled": tm_enabled,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
...@@ -452,6 +455,7 @@ config = mlc.ConfigDict( ...@@ -452,6 +455,7 @@ config = mlc.ConfigDict(
"max_resolution": 3.0, "max_resolution": 3.0,
"eps": eps, # 1e-8, "eps": eps, # 1e-8,
"weight": 0.0, "weight": 0.0,
"enabled": tm_enabled,
}, },
"eps": eps, "eps": eps,
}, },
......
...@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels): ...@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels):
def sigmoid_cross_entropy(logits, labels): def sigmoid_cross_entropy(logits, labels):
log_p = torch.nn.functional.logsigmoid(logits) log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.nn.functional.logsigmoid(-logits) log_not_p = torch.log(torch.sigmoid(-logits))
loss = -labels * log_p - (1 - labels) * log_not_p loss = -labels * log_p - (1 - labels) * log_not_p
return loss return loss
...@@ -1462,7 +1462,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1462,7 +1462,7 @@ class AlphaFoldLoss(nn.Module):
self.config = config self.config = config
def forward(self, out, batch): def forward(self, out, batch):
if "violation" not in out.keys() and self.config.violation.weight: if "violation" not in out.keys():
out["violation"] = find_structural_violations( out["violation"] = find_structural_violations(
batch, batch,
out["sm"]["positions"][-1], out["sm"]["positions"][-1],
...@@ -1509,16 +1509,17 @@ class AlphaFoldLoss(nn.Module): ...@@ -1509,16 +1509,17 @@ class AlphaFoldLoss(nn.Module):
out["violation"], out["violation"],
**batch, **batch,
), ),
"tm": lambda: tm_loss( }
if(self.config.tm.enabled):
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"], logits=out["tm_logits"],
**{**batch, **out, **self.config.tm}, **{**batch, **out, **self.config.tm},
), )
}
cum_loss = 0. cum_loss = 0.
for loss_name, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight weight = self.config[loss_name].weight
if weight:
loss = loss_fn() loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)): if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...") logging.warning(f"{loss_name} loss is NaN. Skipping...")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment