Commit 614e2763 authored by zhuww's avatar zhuww
Browse files

support running in multimer mode

parent 7e01f6d6
...@@ -575,7 +575,7 @@ multimer_model_config_update = { ...@@ -575,7 +575,7 @@ multimer_model_config_update = {
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": tm_enabled, "enabled": True,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
......
...@@ -881,18 +881,20 @@ def _process_single_hit( ...@@ -881,18 +881,20 @@ def _process_single_hit(
) as e: ) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a # These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings. # problem with the template search, so turn them into warnings.
warning = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: " # warning = (
"%s, mmCIF parsing errors: %s" # "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
% ( # "%s, mmCIF parsing errors: %s"
hit_pdb_code, # % (
hit_chain_id, # hit_pdb_code,
hit.sum_probs, # hit_chain_id,
hit.index, # hit.sum_probs,
str(e), # hit.index,
parsing_result.errors, # str(e),
) # parsing_result.errors,
) # )
# )
warning=None
if strict_error_check: if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None) return SingleHitResult(features=None, error=warning, warning=None)
else: else:
......
...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc ...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from fastfold.model.fastnn import MSACore, OutProductMean, PairCore from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks from fastfold.utils.checkpointing import checkpoint_blocks
...@@ -49,6 +49,9 @@ class Evoformer(nn.Module): ...@@ -49,6 +49,9 @@ class Evoformer(nn.Module):
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size)) m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size)) z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1) m = scatter(m, dim=1)
z = scatter(z, dim=1) z = scatter(z, dim=1)
...@@ -76,6 +79,9 @@ class Evoformer(nn.Module): ...@@ -76,6 +79,9 @@ class Evoformer(nn.Module):
m = m.squeeze(0) m = m.squeeze(0)
z = z.squeeze(0) z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0) m = gather(m, dim=0)
z = gather(z, dim=0) z = gather(z, dim=0)
...@@ -107,7 +113,9 @@ class Evoformer(nn.Module): ...@@ -107,7 +113,9 @@ class Evoformer(nn.Module):
z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size)) z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
torch.cuda.empty_cache() torch.cuda.empty_cache()
if self.is_multimer:
m[0] = scatter(m[0], dim=2)
else:
m[0] = scatter(m[0], dim=1, drop_unused=True) m[0] = scatter(m[0], dim=1, drop_unused=True)
torch.cuda.empty_cache() torch.cuda.empty_cache()
z[0] = scatter(z[0], dim=1, drop_unused=True) z[0] = scatter(z[0], dim=1, drop_unused=True)
...@@ -126,15 +134,8 @@ class Evoformer(nn.Module): ...@@ -126,15 +134,8 @@ class Evoformer(nn.Module):
z = self.pair.inplace(z, pair_mask) z = self.pair.inplace(z, pair_mask)
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2) m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
else: else:
# z = self.communication.inplace(m[0], msa_mask, z)
# z_ori = z[0].clone()
# m[0], work = All_to_All_Async.apply(m[0], 1, 2)
# z = self.pair_stack.inplace(z, pair_mask)
# m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
# m[0] = self.msa_stack(m[0], z_ori, msa_mask)
z = self.communication.inplace(m[0], msa_mask, z) z = self.communication.inplace(m[0], msa_mask, z)
m[0], work = All_to_All_Async.apply(m[0], 1, 2) m[0] = col_to_row(m[0])
m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
m[0] = self.msa(m[0], z[0], msa_mask) m[0] = self.msa(m[0], z[0], msa_mask)
z = self.pair.inplace(z, pair_mask) z = self.pair.inplace(z, pair_mask)
...@@ -143,6 +144,9 @@ class Evoformer(nn.Module): ...@@ -143,6 +144,9 @@ class Evoformer(nn.Module):
z[0] = z[0].squeeze(0) z[0] = z[0].squeeze(0)
torch.cuda.empty_cache() torch.cuda.empty_cache()
if self.is_multimer:
m[0] = gather(m[0], dim=1)
else:
m[0] = gather(m[0], dim=0, chunks=4) m[0] = gather(m[0], dim=0, chunks=4)
z[0] = gather(z[0], dim=0, chunks=4) z[0] = gather(z[0], dim=0, chunks=4)
......
...@@ -530,7 +530,7 @@ class InvariantPointAttention(nn.Module): ...@@ -530,7 +530,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s] # [*, N_res, C_s]
if self.is_multimer: if self.is_multimer:
s = self.linear_out( s = self.linear_out(
torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype) torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
) )
else: else:
s = self.linear_out( s = self.linear_out(
...@@ -874,7 +874,8 @@ class StructureModule(nn.Module): ...@@ -874,7 +874,8 @@ class StructureModule(nn.Module):
def _forward_multimer( def _forward_multimer(
self, self,
evoformer_output_dict, s: torch.Tensor,
z: torch.Tensor,
aatype: torch.Tensor, aatype: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -898,6 +899,7 @@ class StructureModule(nn.Module): ...@@ -898,6 +899,7 @@ class StructureModule(nn.Module):
s.device, s.device,
) )
outputs = [] outputs = []
z = [z]
for i in range(self.no_blocks): for i in range(self.no_blocks):
# [*, N, C_s] # [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask) s = s + self.ipa(s, z, rigids, mask)
...@@ -960,7 +962,7 @@ class StructureModule(nn.Module): ...@@ -960,7 +962,7 @@ class StructureModule(nn.Module):
A dictionary of outputs A dictionary of outputs
""" """
if self.is_multimer: if self.is_multimer:
outputs = self._forward_multimer(evoformer_output_dict, aatype, mask) outputs = self._forward_multimer(evoformer_output_dict["single"], evoformer_output_dict["pair"], aatype, mask)
else: else:
outputs = self._forward_monomer(evoformer_output_dict, aatype, mask) outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)
......
...@@ -127,10 +127,8 @@ def assign(translation_dict, orig_weights): ...@@ -127,10 +127,8 @@ def assign(translation_dict, orig_weights):
print(weights[0].shape) print(weights[0].shape)
raise raise
def get_translation_dict(model, version):
def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False): is_multimer = "multimer" in version
data = np.load(npz_path)
# translations = get_translation_dict(model, is_multimer=("multimer" in version))
####################### #######################
# Some templates # Some templates
####################### #######################
...@@ -540,8 +538,6 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = ...@@ -540,8 +538,6 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
}, },
} }
# return translations
no_templ = [ no_templ = [
"model_3", "model_3",
"model_4", "model_4",
...@@ -557,11 +553,20 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = ...@@ -557,11 +553,20 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
if "template_" in k: if "template_" in k:
evo_dict.pop(k) evo_dict.pop(k)
if "_ptm" in version: if "_ptm" in version or is_multimer:
translations["predicted_aligned_error_head"] = { translations["predicted_aligned_error_head"] = {
"logits": LinearParams(model.aux_heads.tm.linear) "logits": LinearParams(model.aux_heads.tm.linear)
} }
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(model, version)
# Flatten keys and insert missing key prefixes # Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations) flat = _process_translations_dict(translations)
...@@ -578,5 +583,3 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = ...@@ -578,5 +583,3 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool =
# Set weights # Set weights
assign(flat, data) assign(flat, data)
...@@ -119,7 +119,8 @@ class FastFoldDataWorkFlow: ...@@ -119,7 +119,8 @@ class FastFoldDataWorkFlow:
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None: def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data" timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
if storage_dir is not None: if storage_dir is not None:
if not os.path.exists(storage_dir): if not os.path.exists(storage_dir):
os.makedirs(storage_dir[7:], exist_ok=True) os.makedirs(storage_dir[7:], exist_ok=True)
......
...@@ -137,7 +137,8 @@ class FastFoldMultimerDataWorkFlow: ...@@ -137,7 +137,8 @@ class FastFoldMultimerDataWorkFlow:
def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None: def run(self, fasta_path: str, alignment_dir: str=None, storage_dir: str=None) -> None:
storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data" timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
if storage_dir is not None: if storage_dir is not None:
if not os.path.exists(storage_dir): if not os.path.exists(storage_dir):
os.makedirs(storage_dir[7:], exist_ok=True) os.makedirs(storage_dir[7:], exist_ok=True)
......
...@@ -312,6 +312,7 @@ def inference_multimer_model(args): ...@@ -312,6 +312,7 @@ def inference_multimer_model(args):
with open(unrelaxed_output_path, 'w') as f: with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein)) f.write(protein.to_pdb(unrelaxed_protein))
if(args.relaxation):
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
use_gpu=True, use_gpu=True,
**config.relax, **config.relax,
...@@ -328,6 +329,15 @@ def inference_multimer_model(args): ...@@ -328,6 +329,15 @@ def inference_multimer_model(args):
with open(relaxed_output_path, 'w') as f: with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str) f.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Model output written to {output_dict_path}...")
def inference_monomer_model(args): def inference_monomer_model(args):
print("running in monomer mode...") print("running in monomer mode...")
...@@ -454,6 +464,7 @@ def inference_monomer_model(args): ...@@ -454,6 +464,7 @@ def inference_monomer_model(args):
with open(unrelaxed_output_path, 'w') as f: with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein)) f.write(protein.to_pdb(unrelaxed_protein))
if(args.relaxation):
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
use_gpu=True, use_gpu=True,
**config.relax, **config.relax,
...@@ -512,6 +523,9 @@ if __name__ == "__main__": ...@@ -512,6 +523,9 @@ if __name__ == "__main__":
help="""Path to model parameters. If None, parameters are selected help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from automatically according to the model name from
./data/params""") ./data/params""")
parser.add_argument(
"--relaxation", action="store_false", default=False,
)
parser.add_argument( parser.add_argument(
"--save_outputs", action="store_true", default=False, "--save_outputs", action="store_true", default=False,
help="Whether to save all model outputs, including embeddings, etc." help="Whether to save all model outputs, including embeddings, etc."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment