Unverified Commit 2e17db8a authored by Younes Belkada, committed by GitHub

[ESM] fix `accelerate` tests for esmfold (#20387)

* fix `accelerate` tests for esmfold

* cleaner solution
parent d2357a01
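Taken together, the three changes below let ESMFold load and run under accelerate's automatic device placement. A minimal sketch of the workflow the fixed tests exercise, assuming `accelerate` is installed and using the public `facebook/esmfold_v1` checkpoint with an arbitrary example sequence (this snippet is illustrative, not code from this commit):

```python
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
# device_map="auto" asks accelerate to dispatch submodules across devices.
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", device_map="auto")

inputs = tokenizer(
    ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],  # arbitrary example sequence
    return_tensors="pt",
    add_special_tokens=False,
)
with torch.no_grad():
    # The fixes below keep every op in the forward pass device-consistent.
    outputs = model(**inputs)
print(outputs.positions.shape)
```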
@@ -638,7 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel):
     config_class = EsmConfig
     base_model_prefix = "esm"
-    _no_split_modules = ["EsmLayer"]
+    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock"]
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
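The first change matters because accelerate consults `_no_split_modules` when building a device map: classes listed there are never sharded across devices, so adding `EsmFoldTriangularSelfAttentionBlock` keeps each folding-trunk block's weights together. A sketch of that interaction using accelerate's public `init_empty_weights` and `infer_auto_device_map` (the setup is illustrative, not from the commit):

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, EsmForProteinFolding

config = AutoConfig.from_pretrained("facebook/esmfold_v1")
with init_empty_weights():
    # Instantiate on the meta device: no memory used, just the module tree.
    model = EsmForProteinFolding(config)

# `from_pretrained(..., device_map="auto")` does the equivalent of this:
# classes named in `_no_split_modules` are passed as no_split_module_classes,
# so accelerate never places half of a block on one device and half on another.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=model._no_split_modules,
)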
@@ -1956,9 +1956,9 @@ class EsmFoldingTrunk(nn.Module):
         for recycle_idx in range(no_recycles):
             with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
                 # === Recycling ===
-                recycle_s = self.recycle_s_norm(recycle_s.detach())
-                recycle_z = self.recycle_z_norm(recycle_z.detach())
-                recycle_z += self.recycle_disto(recycle_bins.detach())
+                recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
+                recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
+                recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

                 s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
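The recycling fix addresses the case where, under a device map, the recycling norm layers sit on a different device than the trunk that consumes their outputs, which makes the `s_s_0 + recycle_s` additions fail with a device-mismatch RuntimeError. A self-contained sketch of that failure mode and the `.to(device)` remedy (the modules, shapes, and devices here are illustrative):

```python
import torch

# Two pieces that accelerate could place on different devices.
norm = torch.nn.LayerNorm(8)                               # e.g. lands on device A
device = "cuda:0" if torch.cuda.is_available() else "cpu"  # the trunk's device

recycle_s = torch.randn(2, 8)                    # produced wherever `norm` lives
recycle_s = norm(recycle_s.detach()).to(device)  # the fix: hop to the trunk device

s_s_0 = torch.randn(2, 8, device=device)
s_s = s_s_0 + recycle_s                          # the add no longer mixes devices
```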
@@ -2207,6 +2207,9 @@ class EsmForProteinFolding(EsmPreTrainedModel):
         return EsmForProteinFoldingOutput(**structure)

     def af2_idx_to_esm_idx(self, aa, mask):
+        # avoid indexing on different devices
+        if self.af2_to_esm.device != aa.device:
+            self.af2_to_esm = self.af2_to_esm.to(aa.device)
         aa = (aa + 1).masked_fill(mask != 1, 0)
         return self.af2_to_esm[aa]
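The last change guards the `af2_to_esm` lookup table: PyTorch raises a RuntimeError when a tensor is indexed with indices on a different device, which happens once accelerate moves the input ids to a GPU while the table stays on CPU. A standalone illustration of the same lazy move (the names mirror the diff, but the values and setup are hypothetical):

```python
import torch

af2_to_esm = torch.arange(21)       # lookup table, e.g. still on CPU
aa = torch.tensor([[0, 3, 7]])      # amino-acid ids
if torch.cuda.is_available():
    aa = aa.to("cuda")              # accelerate may have placed inputs here

# Same check as the diff: indexing across devices would raise a RuntimeError,
# so move the table to wherever the indices live before the gather.
if af2_to_esm.device != aa.device:
    af2_to_esm = af2_to_esm.to(aa.device)
print(af2_to_esm[aa + 1])           # device-consistent lookup
```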