"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "6df2505bf0352a7580b33f17ce6844afe04fb7be"
Unverified Commit 6d8b97ec authored by Fazzie-Maqianli, committed by GitHub

support multimer (#54)

parent 1efccb6c
......@@ -59,7 +59,7 @@ Run the following command to build a docker image from Dockerfile provided.
```shell
cd ColossalAI
docker build -t fastfold ./docker
docker build -t Fastfold ./docker
```
Run the following command to start the docker container in interactive mode.
......
......@@ -74,6 +74,24 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "relax":
pass
elif "multimer" in name:
c.globals.is_multimer = True
c.data.predict.max_msa_clusters = 252 # 128 for monomer
c.model.structure_module.trans_scale_factor = 20 # 10 for monomer
for k, v in multimer_model_config_update.items():
c.model[k] = v
c.data.common.unsupervised_features.extend(
[
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
]
)
else:
raise ValueError("Invalid model name")
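For orientation, a minimal usage sketch of the branch above. The import path and the exact multimer model-name string are assumptions; neither appears in this hunk.

```python
from fastfold.config import model_config  # import path assumed

# Any name containing "multimer" takes the new branch (accepted name strings assumed).
c = model_config("model_1_multimer")
assert c.globals.is_multimer
print(c.data.predict.max_msa_clusters)              # 252 (128 for monomer)
print(c.model.structure_module.trans_scale_factor)  # 20 (10 for monomer)
```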
......@@ -275,6 +293,7 @@ config = mlc.ConfigDict(
"c_e": c_e,
"c_s": c_s,
"eps": eps,
"is_multimer": False,
},
"model": {
"_mask_trans": False,
......@@ -494,4 +513,77 @@ config = mlc.ConfigDict(
},
"ema": {"decay": 0.999},
}
)
\ No newline at end of file
)
multimer_model_config_update = {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"c_t": c_t,
"c_z": c_z,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
}
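As a quick sanity check on the relative-position settings above, the bin count implied by the update dict works out as follows (a small sketch; the formula mirrors the one used by `InputEmbedderMultimer` later in this commit):

```python
# Values taken from the "input_embedder" block of multimer_model_config_update.
max_relative_idx = 32
max_relative_chain = 2
use_chain_relative = True

if use_chain_relative:
    # residue-offset one-hot (+1 "different chain" bin) + same-entity flag
    # + chain-offset one-hot (+1 "different entity" bin)
    no_bins = (2 * max_relative_idx + 2) + 1 + (2 * max_relative_chain + 2)
else:
    no_bins = 2 * max_relative_idx + 1

print(no_bins)  # 73 -> input width of linear_relpos in InputEmbedderMultimer
```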
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
import logging
from typing import Optional, Sequence
from fastfold.data import parsers
from fastfold.data.tools import hmmbuild
from fastfold.utils import general_utils as utils
class Hmmsearch(object):
"""Python wrapper of the hmmsearch binary."""
def __init__(
self,
*,
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
flags: Optional[Sequence[str]] = None,
):
"""Initializes the Python hmmsearch wrapper.
Args:
binary_path: The path to the hmmsearch executable.
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
an hmm from an input a3m.
database_path: The path to the hmmsearch database (FASTA format).
flags: List of flags to be used by hmmsearch.
Raises:
RuntimeError: If hmmsearch binary not found within the path.
"""
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = [
"--F1",
"0.1",
"--F2",
"0.1",
"--F3",
"0.1",
"--incE",
"100",
"-E",
"100",
"--domE",
"100",
"--incdomE",
"100",
]
self.flags = flags
if not os.path.exists(self.database_path):
logging.error("Could not find hmmsearch database %s", database_path)
raise ValueError(f"Could not find hmmsearch database {database_path}")
@property
def output_format(self) -> str:
return "sto"
@property
def input_format(self) -> str:
return "sto"
def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(
msa_sto, model_construction="hand"
)
return self.query_with_hmm(hmm, output_dir)
def query_with_hmm(self, hmm: str, output_dir: Optional[str] = None) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, "query.hmm")
output_dir = query_tmp_dir if output_dir is None else output_dir
out_path = os.path.join(output_dir, "hmm_output.sto")
with open(hmm_input_path, "w") as f:
f.write(hmm)
cmd = [
self.binary_path,
"--noali", # Don't include the alignment in stdout.
"--cpu",
"8",
]
# If adding flags, we have to do so before the output and input:
if self.flags:
cmd.extend(self.flags)
cmd.extend(
[
"-A",
out_path,
hmm_input_path,
self.database_path,
]
)
logging.info("Launching sub-process %s", cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing(
f"hmmsearch ({os.path.basename(self.database_path)}) query"
):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
"hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr.decode("utf-8"))
)
with open(out_path) as f:
out_msa = f.read()
return out_msa
@staticmethod
def get_template_hits(
output_string: str, input_sequence: str
) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
template_hits = parsers.parse_hmmsearch_sto(
output_string,
input_sequence,
)
return template_hits
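A hedged usage sketch for the new wrapper; the module path, binary locations, database path, and the example sequence are placeholders, not values from this commit.

```python
from fastfold.data.tools.hmmsearch import Hmmsearch  # module path assumed

hmmsearch_runner = Hmmsearch(
    binary_path="/usr/bin/hmmsearch",           # placeholder
    hmmbuild_binary_path="/usr/bin/hmmbuild",   # placeholder
    database_path="/data/pdb_seqres.fasta",     # FASTA-format template database, placeholder
)

with open("query_msa.sto") as f:                # Stockholm MSA produced upstream
    msa_sto = f.read()

sto_output = hmmsearch_runner.query(msa_sto)    # builds an HMM, then searches the database
hits = Hmmsearch.get_template_hits(sto_output, input_sequence="MKTAYIAK")  # toy sequence
```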
......@@ -43,7 +43,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
if dim == 1:
if dim == 1 and list(tensor.shape)[0] == 1:
output_shape = list(tensor.shape)
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
......
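The tightened condition above means the dim-1 all-gather path only fires when the leading dimension equals 1. A shape-only sketch (world size and shapes are illustrative, not from the diff):

```python
# Shape bookkeeping only; a tensor-parallel world size of 4 is assumed.
world_size = 4
local_shape = [1, 64, 128]           # passes the new `shape[0] == 1` check
gathered_shape = list(local_shape)
gathered_shape[1] *= world_size      # [1, 256, 128] is allocated for the all-gather output
print(gathered_shape)
```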
......@@ -32,7 +32,7 @@ from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_O
class EvoformerBlock(nn.Module):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool=False):
super(EvoformerBlock, self).__init__()
self.first_block = first_block
......@@ -41,6 +41,7 @@ class EvoformerBlock(nn.Module):
self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
......@@ -73,12 +74,19 @@ class EvoformerBlock(nn.Module):
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
m = self.msa_stack(m, z, msa_mask)
z = z + self.communication(m, msa_mask)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = z + self.communication(m, msa_mask)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
else:
z = z + self.communication(m, msa_mask)
z_ori = z
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
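Condensed, the ordering change in EvoformerBlock.forward above looks like this (the All_to_All communication calls are omitted; this is an illustration, not a drop-in replacement):

```python
def evoformer_block_order(m, z, msa_mask, pair_mask, *, is_multimer,
                          msa_stack, communication, pair_stack):
    """Condensed monomer vs. multimer ordering (tensor-parallel calls omitted)."""
    if not is_multimer:
        m = msa_stack(m, z, msa_mask)           # monomer: MSA stack sees the current z
        z = z + communication(m, msa_mask)      # outer-product mean feeds the pair rep
        z = pair_stack(z, pair_mask)
    else:
        z = z + communication(m, msa_mask)      # multimer: communication happens first
        z_ori = z                               # snapshot z before the pair stack
        z = pair_stack(z, pair_mask)
        m = msa_stack(m, z_ori, msa_mask)       # MSA stack attends to the pre-pair-stack z
    return m, z
```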
......@@ -260,7 +268,6 @@ class TemplatePairStackBlock(nn.Module):
single_templates[i] = single
z = torch.cat(single_templates, dim=-4)
if self.last_block:
z = gather(z, dim=1)
z = z[:, :-padding_size, :-padding_size, :]
......
......@@ -26,6 +26,7 @@ from fastfold.utils.feats import (
)
from fastfold.model.nn.embedders import (
InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
......@@ -69,9 +70,14 @@ class AlphaFold(nn.Module):
extra_msa_config = config.extra_msa
# Main trunk + structure module
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
if self.globals.is_multimer:
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
)
else:
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
......
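The branch above amounts to selecting the embedder class from the new globals flag; a condensed equivalent, shown only as an illustration of the design choice:

```python
from fastfold.model.nn.embedders import InputEmbedder, InputEmbedderMultimer

def build_input_embedder(config, is_multimer: bool):
    """Condensed equivalent of the branch above (illustration only)."""
    embedder_cls = InputEmbedderMultimer if is_multimer else InputEmbedder
    return embedder_cls(**config["input_embedder"])
```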
......@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
from typing import Tuple
from typing import Tuple, Dict
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
......@@ -125,6 +125,146 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
max_relative_idx:
    Window size used in residue-level relative positional encoding
use_chain_relative:
    Whether to add chain-relative positional features
max_relative_chain:
    Window size used in chain-level relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if self.use_chain_relative:
self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch: Dict[str, torch.Tensor]):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = asym_id[..., None] == asym_id[..., None, :]
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if self.use_chain_relative:
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = entity_id[..., None] == entity_id[..., None, :]
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain.long(),
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset,
2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
return self.linear_relpos(rel_feat)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
......
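A small self-contained check of the relpos logic above on a toy two-chain batch. The feature key names follow the relpos/forward code in this hunk; the tensor values are made up for illustration.

```python
import torch

# Toy two-chain complex: chain A has 3 residues, chain B has 2.
batch = {
    "residue_index": torch.tensor([0, 1, 2, 0, 1]),
    "asym_id":       torch.tensor([0, 0, 0, 1, 1]),  # per-chain id
    "entity_id":     torch.tensor([0, 0, 0, 1, 1]),  # per-sequence id
    "sym_id":        torch.tensor([0, 0, 0, 0, 0]),  # copy index within an entity
}

max_relative_idx = 32
offset = batch["residue_index"][..., None] - batch["residue_index"][..., None, :]
clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)
same_chain = batch["asym_id"][..., None] == batch["asym_id"][..., None, :]

# Cross-chain pairs collapse into a single "different chain" bin, exactly as in relpos().
final_offset = torch.where(
    same_chain, clipped, (2 * max_relative_idx + 1) * torch.ones_like(clipped)
)
print(final_offset[0, 3])  # 65: residue 0 of chain A vs residue 0 of chain B
```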
......@@ -142,6 +142,7 @@ class TemplatePairStackBlock(nn.Module):
pair_transition_n: int,
dropout_rate: float,
inf: float,
is_multimer: bool=False,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()
......@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.inf = inf
self.is_multimer = is_multimer
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
......@@ -196,43 +198,67 @@ class TemplatePairStackBlock(nn.Module):
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
for i in range(len(single_templates)):
    single = single_templates[i]
    single_mask = single_templates_masks[i]
    single = single + self.dropout_row(
        self.tri_att_start(
            single,
            chunk_size=chunk_size,
            mask=single_mask
        )
    )
    single = single + self.dropout_col(
        self.tri_att_end(
            single,
            chunk_size=chunk_size,
            mask=single_mask
        )
    )
    single = single + self.dropout_row(
        self.tri_mul_out(
            single,
            mask=single_mask
        )
    )
    single = single + self.dropout_row(
        self.tri_mul_in(
            single,
            mask=single_mask
        )
    )
    single = single + self.pair_transition(
        single,
        mask=single_mask if _mask_trans else None,
        chunk_size=chunk_size,
    )
    single_templates[i] = single
if not self.is_multimer:
    for i in range(len(single_templates)):
        single = single_templates[i]
        single_mask = single_templates_masks[i]
        single = single + self.dropout_row(
            self.tri_att_start(
                single,
                chunk_size=chunk_size,
                mask=single_mask
            )
        )
        single = single + self.dropout_col(
            self.tri_att_end(
                single,
                chunk_size=chunk_size,
                mask=single_mask
            )
        )
        single = single + self.dropout_row(
            self.tri_mul_out(
                single,
                mask=single_mask
            )
        )
        single = single + self.dropout_row(
            self.tri_mul_in(
                single,
                mask=single_mask
            )
        )
        single = single + self.pair_transition(
            single,
            mask=single_mask if _mask_trans else None,
            chunk_size=chunk_size,
        )
        single_templates[i] = single
else:
    for i in range(len(single_templates)):
        single = single_templates[i]
        single_mask = single_templates_masks[i]
        single = single + self.dropout_row(
            self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask)
        )
        single = single + self.dropout_col(
            self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask)
        )
        single = single + self.dropout_row(
            self.tri_mul_out(single, mask=single_mask)
        )
        single = single + self.dropout_row(
            self.tri_mul_in(single, mask=single_mask)
        )
        single = single + self.pair_transition(
            single,
            mask=single_mask if _mask_trans else None,
            chunk_size=chunk_size,
        )
        single_templates[i] = single
z = torch.cat(single_templates, dim=-4)
......
import os
import time
from multiprocessing import cpu_count
import ray
from ray import workflow
from fastfold.workflow.factory import JackHmmerFactory, HHSearchFactory, HHBlitsFactory
from fastfold.workflow import batch_run
......@@ -80,6 +81,7 @@ class FastFoldDataWorkFlow:
print("Workflow not found. Clean. Skipping")
pass
# prepare alignment directory for alignment outputs
if alignment_dir is None:
alignment_dir = os.path.join(output_dir, "alignment")
......
......@@ -23,6 +23,7 @@ from datetime import date
import numpy as np
import torch
import torch.multiprocessing as mp
import pickle
from fastfold.model.hub import AlphaFold
import fastfold
......@@ -111,6 +112,7 @@ def inference_model(rank, world_size, result_q, batch, args):
def main(args):
config = model_config(args.model_name)
global_is_multimer = True if args.model_preset == "multimer" else False
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
......@@ -147,6 +149,7 @@ def main(args):
seqs, tags = parse_fasta(fasta)
for tag, seq in zip(tags, seqs):
print(f"tag:{tag} seq:{seq}")
batch = [None]
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
......@@ -155,44 +158,48 @@ def main(args):
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
if args.enable_workflow:
print("Running alignment with ray workflow...")
alignment_data_workflow_runner = FastFoldDataWorkFlow(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
t = time.perf_counter()
alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
print(f"Alignment data workflow time: {time.perf_counter() - t}")
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# if global_is_multimer:
# print("Multimer")
# else:
# if (args.use_precomputed_alignments is None):
# if not os.path.exists(local_alignment_dir):
# os.makedirs(local_alignment_dir)
# if args.enable_workflow:
# print("Running alignment with ray workflow...")
# alignment_data_workflow_runner = FastFoldDataWorkFlow(
# jackhmmer_binary_path=args.jackhmmer_binary_path,
# hhblits_binary_path=args.hhblits_binary_path,
# hhsearch_binary_path=args.hhsearch_binary_path,
# uniref90_database_path=args.uniref90_database_path,
# mgnify_database_path=args.mgnify_database_path,
# bfd_database_path=args.bfd_database_path,
# uniclust30_database_path=args.uniclust30_database_path,
# pdb70_database_path=args.pdb70_database_path,
# use_small_bfd=use_small_bfd,
# no_cpus=args.cpus,
# )
# t = time.perf_counter()
# alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
# print(f"Alignment data workflow time: {time.perf_counter() - t}")
# else:
# alignment_runner = data_pipeline.AlignmentRunner(
# jackhmmer_binary_path=args.jackhmmer_binary_path,
# hhblits_binary_path=args.hhblits_binary_path,
# hhsearch_binary_path=args.hhsearch_binary_path,
# uniref90_database_path=args.uniref90_database_path,
# mgnify_database_path=args.mgnify_database_path,
# bfd_database_path=args.bfd_database_path,
# uniclust30_database_path=args.uniclust30_database_path,
# pdb70_database_path=args.pdb70_database_path,
# use_small_bfd=use_small_bfd,
# no_cpus=args.cpus,
# )
# alignment_runner.run(fasta_path, local_alignment_dir)
# feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
# alignment_dir=local_alignment_dir)
feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
# Remove temporary FASTA file
os.remove(fasta_path)
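The hard-coded pickle path above is a local debugging shortcut. For reference, a compatible pickle can be written from any previously computed feature dict with the standard library; the output path below is a placeholder.

```python
import pickle

def dump_features(feature_dict, path="features_pdb1o5d.pkl"):
    """Write a feature dict that the debug load above can read (path is a placeholder)."""
    with open(path, "wb") as f:
        pickle.dump(feature_dict, f)
```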
......@@ -289,6 +296,14 @@ if __name__ == "__main__":
default='full_dbs',
choices=('reduced_dbs', 'full_dbs'))
parser.add_argument('--data_random_seed', type=str, default=None)
parser.add_argument(
"--model_preset",
type=str,
default="monomer",
choices=["monomer", "multimer"],
help="Choose preset model configuration - the monomer model, the monomer model with "
"extra ensembling, monomer model with pTM head, or multimer model",
)
add_data_args(parser)
args = parser.parse_args()
......