"demo/java/src/ImageRestorer.java" did not exist on "a64900f3382e8e32155c47b8c5597022480b20ac"
Commit 3aaf0ca8 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update trainining code with new input from new multimer pipeline

parent 74670a88
...@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return self.model(batch) return self.model(batch)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
features,gt_features = batch
# Log it # Log it
if(self.ema.device != batch["aatype"].device): if(self.ema.device != features["aatype"].device):
self.ema.to(batch["aatype"].device) self.ema.to(features["aatype"].device)
# Run the model # Run the model
outputs = self(batch) outputs = self(features)
# Remove the recycling dimension # Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch) features = tensor_tree_map(lambda t: t[..., -1], features)
# Compute loss # Compute loss
loss, loss_breakdown = self.loss( loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True outputs, (features,gt_features), _return_breakdown=True
) )
# Log it # Log it
self._log(loss_breakdown, batch, outputs) self._log(loss_breakdown, features, outputs)
return loss return loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
features,gt_features = batch
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather # model.state_dict() contains references to model weights rather
...@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper): ...@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model # Run the model
outputs = self(batch) outputs = self(features)
# Compute loss and other metrics # Compute loss and other metrics
batch["use_clamped_fape"] = 0. features["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss( _, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True outputs, (features,gt_features), _return_breakdown=True
) )
self._log(loss_breakdown, batch, outputs, train=False) self._log(loss_breakdown, features, outputs, train=False)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
# Restore the model weights to normal # Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment