Commit 614e2763 authored by zhuww's avatar zhuww
Browse files

support running in multimer mode

parent 7e01f6d6
...@@ -575,7 +575,7 @@ multimer_model_config_update = { ...@@ -575,7 +575,7 @@ multimer_model_config_update = {
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": tm_enabled, "enabled": True,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
......
...@@ -881,18 +881,20 @@ def _process_single_hit( ...@@ -881,18 +881,20 @@ def _process_single_hit(
) as e: ) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a # These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings. # problem with the template search, so turn them into warnings.
warning = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: " # warning = (
"%s, mmCIF parsing errors: %s" # "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
% ( # "%s, mmCIF parsing errors: %s"
hit_pdb_code, # % (
hit_chain_id, # hit_pdb_code,
hit.sum_probs, # hit_chain_id,
hit.index, # hit.sum_probs,
str(e), # hit.index,
parsing_result.errors, # str(e),
) # parsing_result.errors,
) # )
# )
warning=None
if strict_error_check: if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None) return SingleHitResult(features=None, error=warning, warning=None)
else: else:
......
...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc ...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks from fastfold.utils.checkpointing import checkpoint_blocks
...@@ -49,7 +49,10 @@ class Evoformer(nn.Module): ...@@ -49,7 +49,10 @@ class Evoformer(nn.Module):
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size)) m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size)) z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
m = scatter(m, dim=1) if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1) z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0) msa_mask = msa_mask.unsqueeze(0)
...@@ -76,7 +79,10 @@ class Evoformer(nn.Module): ...@@ -76,7 +79,10 @@ class Evoformer(nn.Module):
m = m.squeeze(0) m = m.squeeze(0)
z = z.squeeze(0) z = z.squeeze(0)
m = gather(m, dim=0) if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0) z = gather(z, dim=0)
m = m[:, :-padding_size, :] m = m[:, :-padding_size, :]
...@@ -107,8 +113,10 @@ class Evoformer(nn.Module): ...@@ -107,8 +113,10 @@ class Evoformer(nn.Module):
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size)) z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
torch.cuda.empty_cache() torch.cuda.empty_cache()
if self.is_multimer:
m[0] = scatter(m[0], dim=1, drop_unused=True) m[0] = scatter(m[0], dim=2)
else:
m[0] = scatter(m[0], dim=1, drop_unused=True)
torch.cuda.empty_cache() torch.cuda.empty_cache()
z[0] = scatter(z[0], dim=1, drop_unused=True) z[0] = scatter(z[0], dim=1, drop_unused=True)
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -126,15 +134,8 @@ class Evoformer(nn.Module): ...@@ -126,15 +134,8 @@ class Evoformer(nn.Module):
z = self.pair.inplace(z, pair_mask) z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2) m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else: else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z) z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2) m[0] = col_to_row(m[0])
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = self.msa(m[0], z[0], msa_mask) m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask) z = self.pair.inplace(z, pair_mask)
...@@ -143,7 +144,10 @@ class Evoformer(nn.Module): ...@@ -143,7 +144,10 @@ class Evoformer(nn.Module):
z[0] = z[0].squeeze(0) z[0] = z[0].squeeze(0)
torch.cuda.empty_cache() torch.cuda.empty_cache()
m[0] = gather(m[0], dim=0, chunks=4) if self.is_multimer:
m[0] = gather(m[0], dim=1)
else:
m[0] = gather(m[0], dim=0, chunks=4)
z[0] = gather(z[0], dim=0, chunks=4) z[0] = gather(z[0], dim=0, chunks=4)
m[0] = m[0][:, :-padding_size, :] m[0] = m[0][:, :-padding_size, :]
......
...@@ -360,7 +360,7 @@ class AlphaFold(nn.Module): ...@@ -360,7 +360,7 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=z[0].dtype), pair_mask=pair_mask.to(dtype=z[0].dtype),
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
)[0] )[0]
del extra_msa_feat, extra_msa_fn del extra_msa_feat, extra_msa_fn
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
...@@ -530,7 +530,7 @@ class InvariantPointAttention(nn.Module): ...@@ -530,7 +530,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s] # [*, N_res, C_s]
if self.is_multimer: if self.is_multimer:
s = self.linear_out( s = self.linear_out(
torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype) torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
) )
else: else:
s = self.linear_out( s = self.linear_out(
...@@ -874,7 +874,8 @@ class StructureModule(nn.Module): ...@@ -874,7 +874,8 @@ class StructureModule(nn.Module):
def _forward_multimer( def _forward_multimer(
self, self,
evoformer_output_dict, s: torch.Tensor,
z: torch.Tensor,
aatype: torch.Tensor, aatype: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -898,6 +899,7 @@ class StructureModule(nn.Module): ...@@ -898,6 +899,7 @@ class StructureModule(nn.Module):
s.device, s.device,
) )
outputs = [] outputs = []
z = [z]
for i in range(self.no_blocks): for i in range(self.no_blocks):
# [*, N, C_s] # [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask) s = s + self.ipa(s, z, rigids, mask)
...@@ -960,7 +962,7 @@ class StructureModule(nn.Module): ...@@ -960,7 +962,7 @@ class StructureModule(nn.Module):
A dictionary of outputs A dictionary of outputs
""" """
if self.is_multimer: if self.is_multimer:
outputs = self._forward_multimer(evoformer_output_dict, aatype, mask) outputs = self._forward_multimer(evoformer_output_dict["single"], evoformer_output_dict["pair"], aatype, mask)
else: else:
outputs = self._forward_monomer(evoformer_output_dict, aatype, mask) outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)
......
...@@ -126,11 +126,9 @@ def assign(translation_dict, orig_weights): ...@@ -126,11 +126,9 @@ def assign(translation_dict, orig_weights):
print(ref[0].shape) print(ref[0].shape)
print(weights[0].shape) print(weights[0].shape)
raise raise
def get_translation_dict(model, version):
def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False): is_multimer = "multimer" in version
data = np.load(npz_path)
# translations = get_translation_dict(model, is_multimer=("multimer" in version))
####################### #######################
# Some templates # Some templates
####################### #######################
...@@ -540,16 +538,14 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = ...@@ -540,16 +538,14 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
}, },
} }
# return translations
no_templ = [ no_templ = [
"model_3", "model_3",
"model_4", "model_4",
"model_5", "model_5",
"model_3_ptm", "model_3_ptm",
"model_4_ptm", "model_4_ptm",
"model_5_ptm", "model_5_ptm",
] ]
if version in no_templ: if version in no_templ:
evo_dict = translations["evoformer"] evo_dict = translations["evoformer"]
keys = list(evo_dict.keys()) keys = list(evo_dict.keys())
...@@ -557,10 +553,19 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = ...@@ -557,10 +553,19 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
if "template_" in k: if "template_" in k:
evo_dict.pop(k) evo_dict.pop(k)
if "_ptm" in version: if "_ptm" in version or is_multimer:
translations["predicted_aligned_error_head"] = { translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear) "logits": LinearParams(model.aux_heads.tm.linear)
} }
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, version)
# Flatten keys and insert missing key prefixes # Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations) flat = _process_translations_dict(translations)
...@@ -578,5 +583,3 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = ...@@ -578,5 +583,3 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
# Set weights # Set weights
assign(flat, data) assign(flat, data)
...@@ -119,7 +119,8 @@ class FastFoldDataWorkFlow: ...@@ -119,7 +119,8 @@ class FastFoldDataWorkFlow:
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None: def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data" timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
if storage_dir is not None: if storage_dir is not None:
if not os.path.exists(storage_dir): if not os.path.exists(storage_dir):
os.makedirs(storage_dir[7:], exist_ok=True) os.makedirs(storage_dir[7:], exist_ok=True)
......
...@@ -137,7 +137,8 @@ class FastFoldMultimerDataWorkFlow: ...@@ -137,7 +137,8 @@ class FastFoldMultimerDataWorkFlow:
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None: def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data" timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
if storage_dir is not None: if storage_dir is not None:
if not os.path.exists(storage_dir): if not os.path.exists(storage_dir):
os.makedirs(storage_dir[7:], exist_ok=True) os.makedirs(storage_dir[7:], exist_ok=True)
......
...@@ -312,21 +312,31 @@ def inference_multimer_model(args): ...@@ -312,21 +312,31 @@ def inference_multimer_model(args):
with open(unrelaxed_output_path, 'w') as f: with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein)) f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation( if(args.relaxation):
use_gpu=True, amber_relaxer = relax.AmberRelaxation(
**config.relax, use_gpu=True,
) **config.relax,
)
# Relax the prediction. # Relax the prediction.
t = time.perf_counter() t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}") print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
# Save the relaxed PDB. logger.info(f"Model output written to {output_dict_path}...")
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
def inference_monomer_model(args): def inference_monomer_model(args):
...@@ -454,21 +464,22 @@ def inference_monomer_model(args): ...@@ -454,21 +464,22 @@ def inference_monomer_model(args):
with open(unrelaxed_output_path, 'w') as f: with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein)) f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation( if(args.relaxation):
use_gpu=True, amber_relaxer = relax.AmberRelaxation(
**config.relax, use_gpu=True,
) **config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB. # Relax the prediction.
relaxed_output_path = os.path.join(args.output_dir, t = time.perf_counter()
f'{tag}_{args.model_name}_relaxed.pdb') relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
with open(relaxed_output_path, 'w') as f: print(f"Relaxation time: {time.perf_counter() - t}")
f.write(relaxed_pdb_str)
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if(args.save_outputs): if(args.save_outputs):
output_dict_path = os.path.join( output_dict_path = os.path.join(
...@@ -512,6 +523,9 @@ if __name__ == "__main__": ...@@ -512,6 +523,9 @@ if __name__ == "__main__":
help="""Path to model parameters. If None, parameters are selected help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from automatically according to the model name from
./data/params""") ./data/params""")
parser.add_argument(
"--relaxation", action="store_false", default=False,
)
parser.add_argument( parser.add_argument(
"--save_outputs", action="store_true", default=False, "--save_outputs", action="store_true", default=False,
help="Whether to save all model outputs, including embeddings, etc." help="Whether to save all model outputs, including embeddings, etc."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment