Commit 3aaf0ca8 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update trainining code with new input from new multimer pipeline

parent 74670a88
......@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return self.model(batch)
def training_step(self, batch, batch_idx):
features,gt_features = batch
# Log it
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
if(self.ema.device != features["aatype"].device):
self.ema.to(features["aatype"].device)
# Run the model
outputs = self(batch)
outputs = self(features)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
features = tensor_tree_map(lambda t: t[..., -1], features)
# Compute loss
loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
outputs, (features,gt_features), _return_breakdown=True
)
# Log it
self._log(loss_breakdown, batch, outputs)
self._log(loss_breakdown, features, outputs)
return loss
def validation_step(self, batch, batch_idx):
features,gt_features = batch
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather
......@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
outputs = self(batch)
outputs = self(features)
# Compute loss and other metrics
batch["use_clamped_fape"] = 0.
features["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
outputs, (features,gt_features), _return_breakdown=True
)
self._log(loss_breakdown, batch, outputs, train=False)
self._log(loss_breakdown, features, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment