Unverified Commit 2e17db8a authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[ESM] fix `accelerate` tests for esmfold (#20387)

* fix `accelerate` tests for esmfold

* cleaner solution
parent d2357a01
...@@ -638,7 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel): ...@@ -638,7 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel):
config_class = EsmConfig config_class = EsmConfig
base_model_prefix = "esm" base_model_prefix = "esm"
_no_split_modules = ["EsmLayer"] _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module): def _init_weights(self, module):
......
...@@ -1956,9 +1956,9 @@ class EsmFoldingTrunk(nn.Module): ...@@ -1956,9 +1956,9 @@ class EsmFoldingTrunk(nn.Module):
for recycle_idx in range(no_recycles): for recycle_idx in range(no_recycles):
with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]): with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
# === Recycling === # === Recycling ===
recycle_s = self.recycle_s_norm(recycle_s.detach()) recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
recycle_z = self.recycle_z_norm(recycle_z.detach()) recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
recycle_z += self.recycle_disto(recycle_bins.detach()) recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask) s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
...@@ -2207,6 +2207,9 @@ class EsmForProteinFolding(EsmPreTrainedModel): ...@@ -2207,6 +2207,9 @@ class EsmForProteinFolding(EsmPreTrainedModel):
return EsmForProteinFoldingOutput(**structure) return EsmForProteinFoldingOutput(**structure)
def af2_idx_to_esm_idx(self, aa, mask): def af2_idx_to_esm_idx(self, aa, mask):
# avoid indexing on different devices
if self.af2_to_esm.device != aa.device:
self.af2_to_esm = self.af2_to_esm.to(aa.device)
aa = (aa + 1).masked_fill(mask != 1, 0) aa = (aa + 1).masked_fill(mask != 1, 0)
return self.af2_to_esm[aa] return self.af2_to_esm[aa]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment