"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f3e5700d49d5c7fe609aa16530b1b5d83ae10b90"
Commit 614e2763 authored by zhuww's avatar zhuww
Browse files

support running in multimer mode

parent 7e01f6d6
......@@ -575,7 +575,7 @@ multimer_model_config_update = {
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
"enabled": True,
},
"masked_msa": {
"c_m": c_m,
......
......@@ -881,18 +881,20 @@ def _process_single_hit(
) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
# warning = (
# "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
# "%s, mmCIF parsing errors: %s"
# % (
# hit_pdb_code,
# hit_chain_id,
# hit.sum_probs,
# hit.index,
# str(e),
# parsing_result.errors,
# )
# )
warning=None
if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None)
else:
......
......@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks
......@@ -49,6 +49,9 @@ class Evoformer(nn.Module):
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
......@@ -76,6 +79,9 @@ class Evoformer(nn.Module):
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
......@@ -107,7 +113,9 @@ class Evoformer(nn.Module):
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
torch.cuda.empty_cache()
if self.is_multimer:
m[0] = scatter(m[0], dim=2)
else:
m[0] = scatter(m[0], dim=1, drop_unused=True)
torch.cuda.empty_cache()
z[0] = scatter(z[0], dim=1, drop_unused=True)
......@@ -126,15 +134,8 @@ class Evoformer(nn.Module):
z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = col_to_row(m[0])
m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask)
......@@ -143,6 +144,9 @@ class Evoformer(nn.Module):
z[0] = z[0].squeeze(0)
torch.cuda.empty_cache()
if self.is_multimer:
m[0] = gather(m[0], dim=1)
else:
m[0] = gather(m[0], dim=0, chunks=4)
z[0] = gather(z[0], dim=0, chunks=4)
......
......@@ -530,7 +530,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
if self.is_multimer:
s = self.linear_out(
torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)
torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
)
else:
s = self.linear_out(
......@@ -874,7 +874,8 @@ class StructureModule(nn.Module):
def _forward_multimer(
self,
evoformer_output_dict,
s: torch.Tensor,
z: torch.Tensor,
aatype: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
......@@ -898,6 +899,7 @@ class StructureModule(nn.Module):
s.device,
)
outputs = []
z = [z]
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
......@@ -960,7 +962,7 @@ class StructureModule(nn.Module):
A dictionary of outputs
"""
if self.is_multimer:
outputs = self._forward_multimer(evoformer_output_dict, aatype, mask)
outputs = self._forward_multimer(evoformer_output_dict["single"], evoformer_output_dict["pair"], aatype, mask)
else:
outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)
......
......@@ -127,10 +127,8 @@ def assign(translation_dict, orig_weights):
print(weights[0].shape)
raise
def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False):
data = np.load(npz_path)
# translations = get_translation_dict(model, is_multimer=("multimer" in version))
def get_translation_dict(model, version):
is_multimer = "multimer" in version
#######################
# Some templates
#######################
......@@ -540,8 +538,6 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
},
}
# return translations
no_templ = [
"model_3",
"model_4",
......@@ -557,11 +553,20 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
if "template_" in k:
evo_dict.pop(k)
if "_ptm" in version:
if "_ptm" in version or is_multimer:
translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear)
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, version)
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
......@@ -578,5 +583,3 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
# Set weights
assign(flat, data)
......@@ -119,7 +119,8 @@ class FastFoldDataWorkFlow:
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data"
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
if storage_dir is not None:
if not os.path.exists(storage_dir):
os.makedirs(storage_dir[7:], exist_ok=True)
......
......@@ -137,7 +137,8 @@ class FastFoldMultimerDataWorkFlow:
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data"
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
if storage_dir is not None:
if not os.path.exists(storage_dir):
os.makedirs(storage_dir[7:], exist_ok=True)
......
......@@ -312,6 +312,7 @@ def inference_multimer_model(args):
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
if(args.relaxation):
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
......@@ -328,6 +329,15 @@ def inference_multimer_model(args):
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...")
def inference_monomer_model(args):
print("running in monomer mode...")
......@@ -454,6 +464,7 @@ def inference_monomer_model(args):
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
if(args.relaxation):
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
......@@ -512,6 +523,9 @@ if __name__ == "__main__":
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
./data/params""")
parser.add_argument(
"--relaxation", action="store_false", default=False,
)
parser.add_argument(
"--save_outputs", action="store_true", default=False,
help="Whether to save all model outputs, including embeddings, etc."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment