Unverified commit 6d8b97ec, authored by Fazzie-Maqianli, committed by GitHub

support multimer (#54)

parent 1efccb6c
@@ -59,7 +59,7 @@ Run the following command to build a docker image from Dockerfile provided.
```shell
cd ColossalAI
-docker build -t fastfold ./docker
+docker build -t Fastfold ./docker
```
Run the following command to start the docker container in interactive mode.
......
@@ -74,6 +74,24 @@ def model_config(name, train=False, low_prec=False):
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "relax":
        pass
    elif "multimer" in name:
        c.globals.is_multimer = True
        c.data.predict.max_msa_clusters = 252  # 128 for monomer
        c.model.structure_module.trans_scale_factor = 20  # 10 for monomer
        for k, v in multimer_model_config_update.items():
            c.model[k] = v
        c.data.common.unsupervised_features.extend(
            [
                "msa_mask",
                "seq_mask",
                "asym_id",
                "entity_id",
                "sym_id",
            ]
        )
    else:
        raise ValueError("Invalid model name")
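To see what the new branch changes, here is a minimal sketch of requesting a multimer configuration. The import path and the preset name are assumptions for illustration (the diff only guarantees that any name containing "multimer" reaches this branch), not something confirmed by this commit.

```python
# Sketch only: import path and preset name are assumed, not taken from this commit.
from fastfold.config import model_config

c = model_config("model_1_multimer")          # any name containing "multimer"
assert c.globals.is_multimer
assert c.data.predict.max_msa_clusters == 252            # 128 for monomer presets
assert c.model.structure_module.trans_scale_factor == 20  # 10 for monomer presets
# The multimer-specific sub-configs replace their monomer counterparts:
print(c.model.input_embedder.max_relative_chain)          # 2
```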
@@ -275,6 +293,7 @@ config = mlc.ConfigDict(
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
            "is_multimer": False,
        },
        "model": {
            "_mask_trans": False,
@@ -495,3 +514,76 @@ config = mlc.ConfigDict(
        "ema": {"decay": 0.999},
    }
)

multimer_model_config_update = {
    "input_embedder": {
        "tf_dim": 21,
        "msa_dim": 49,
        "c_z": c_z,
        "c_m": c_m,
        "relpos_k": 32,
        "max_relative_chain": 2,
        "max_relative_idx": 32,
        "use_chain_relative": True,
    },
    "template": {
        "distogram": {
            "min_bin": 3.25,
            "max_bin": 50.75,
            "no_bins": 39,
        },
        "template_pair_embedder": {
            "c_z": c_z,
            "c_out": 64,
            "c_dgram": 39,
            "c_aatype": 22,
        },
        "template_single_embedder": {
            "c_in": 34,
            "c_m": c_m,
        },
        "template_pair_stack": {
            "c_t": c_t,
            # DISCREPANCY: c_hidden_tri_att here is given in the supplement
            # as 64. In the code, it's 16.
            "c_hidden_tri_att": 16,
            "c_hidden_tri_mul": 64,
            "no_blocks": 2,
            "no_heads": 4,
            "pair_transition_n": 2,
            "dropout_rate": 0.25,
            "blocks_per_ckpt": blocks_per_ckpt,
            "inf": 1e9,
        },
        "c_t": c_t,
        "c_z": c_z,
        "inf": 1e5,  # 1e9,
        "eps": eps,  # 1e-6,
        "enabled": templates_enabled,
        "embed_angles": embed_template_torsion_angles,
    },
    "heads": {
        "lddt": {
            "no_bins": 50,
            "c_in": c_s,
            "c_hidden": 128,
        },
        "distogram": {
            "c_z": c_z,
            "no_bins": aux_distogram_bins,
        },
        "tm": {
            "c_z": c_z,
            "no_bins": aux_distogram_bins,
            "enabled": tm_enabled,
        },
        "masked_msa": {
            "c_m": c_m,
            "c_out": 22,
        },
        "experimentally_resolved": {
            "c_s": c_s,
            "c_out": 37,
        },
    },
}
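The loop `c.model[k] = v` in model_config applies this dictionary key by key, so each of `input_embedder`, `template`, and `heads` replaces the corresponding monomer sub-config wholesale rather than being deep-merged into it. A small standalone sketch of that behaviour with toy keys (assuming ml_collections is installed; not code from this commit):

```python
import ml_collections as mlc

c = mlc.ConfigDict({"model": {"input_embedder": {"tf_dim": 22, "relpos_k": 32}}})
update = {"input_embedder": {"tf_dim": 21, "max_relative_chain": 2}}

for k, v in update.items():
    c.model[k] = v  # assignment replaces the whole sub-config

print(sorted(c.model.input_embedder.keys()))  # ['max_relative_chain', 'tf_dim']; 'relpos_k' is gone
print(c.model.input_embedder.tf_dim)          # 21
```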
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import os
import subprocess
import logging
from typing import Optional, Sequence

from fastfold.data import parsers
from fastfold.data.tools import hmmbuild
from fastfold.utils import general_utils as utils


class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None,
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path

        if flags is None:
            # Default hmmsearch run settings.
            flags = [
                "--F1",
                "0.1",
                "--F2",
                "0.1",
                "--F3",
                "0.1",
                "--incE",
                "100",
                "-E",
                "100",
                "--domE",
                "100",
                "--incdomE",
                "100",
            ]
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error("Could not find hmmsearch database %s", database_path)
            raise ValueError(f"Could not find hmmsearch database {database_path}")

    @property
    def output_format(self) -> str:
        return "sto"

    @property
    def input_format(self) -> str:
        return "sto"

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction="hand"
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self, hmm: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, "query.hmm")
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, "hmm_output.sto")
            with open(hmm_input_path, "w") as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                "--noali",  # Don't include the alignment in stdout.
                "--cpu",
                "8",
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend(
                [
                    "-A",
                    out_path,
                    hmm_input_path,
                    self.database_path,
                ]
            )

            logging.info("Launching sub-process %s", cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"hmmsearch ({os.path.basename(self.database_path)}) query"
            ):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
        output_string: str, input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
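A short usage sketch of the wrapper above. The module path is inferred from the package imports in the file, and the binary, database, MSA, and sequence values are placeholders, not paths from this repository; only the Hmmsearch API shown above is assumed.

```python
# Hypothetical paths and inputs; API as defined in the file above.
from fastfold.data.tools.hmmsearch import Hmmsearch

searcher = Hmmsearch(
    binary_path="/usr/bin/hmmsearch",
    hmmbuild_binary_path="/usr/bin/hmmbuild",
    database_path="/data/pdb_seqres.txt",  # FASTA-format template database
)

with open("query_msa.sto") as f:
    msa_sto = f.read()

# Builds an HMM from the Stockholm MSA with hmmbuild, then searches the database.
sto_output = searcher.query(msa_sto)

# Parse template hits for the downstream template featurizer.
hits = Hmmsearch.get_template_hits(sto_output, input_sequence="MKTAYIAKQR")
```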
@@ -43,7 +43,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor

-    if dim == 1:
+    if dim == 1 and list(tensor.shape)[0] == 1:
        output_shape = list(tensor.shape)
        output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
......
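For reference, here is a standalone sketch of the dim-1 gather that the condition above selects, written against plain torch.distributed rather than FastFold's gpc/ParallelMode wrappers. It assumes a process group is already initialized and is an illustration of the shape logic, not the library's implementation.

```python
import torch
import torch.distributed as dist

def gather_dim1(tensor: torch.Tensor) -> torch.Tensor:
    """Gather [*, N/ws, ...] shards from every tensor-parallel rank into [*, N, ...]."""
    ws = dist.get_world_size()
    if ws == 1:
        return tensor
    # Each rank contributes an identically-shaped shard; concatenating the
    # gathered shards along dim 1 reproduces the output_shape computed above.
    shards = [torch.empty_like(tensor) for _ in range(ws)]
    dist.all_gather(shards, tensor.contiguous())
    return torch.cat(shards, dim=1)
```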
@@ -32,7 +32,7 @@ from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_O

class EvoformerBlock(nn.Module):

-    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
+    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool=False):
        super(EvoformerBlock, self).__init__()

        self.first_block = first_block
@@ -41,6 +41,7 @@ class EvoformerBlock(nn.Module):
        self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
        self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=c_z)
        self.is_multimer = is_multimer

    def forward(
        self,
@@ -73,12 +74,19 @@ class EvoformerBlock(nn.Module):
            msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
            pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

-        m = self.msa_stack(m, z, msa_mask)
-        z = z + self.communication(m, msa_mask)
-        m, work = All_to_All_Async.apply(m, 1, 2)
-        z = self.pair_stack(z, pair_mask)
-        m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+        if not self.is_multimer:
+            m = self.msa_stack(m, z, msa_mask)
+            z = z + self.communication(m, msa_mask)
+            m, work = All_to_All_Async.apply(m, 1, 2)
+            z = self.pair_stack(z, pair_mask)
+            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+        else:
+            z = z + self.communication(m, msa_mask)
+            z_ori = z
+            m, work = All_to_All_Async.apply(m, 1, 2)
+            z = self.pair_stack(z, pair_mask)
+            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
+            m = self.msa_stack(m, z_ori, msa_mask)

        if self.last_block:
            m = m.squeeze(0)
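The change above reorders the block for multimer inputs: the outer-product-mean "communication" runs before the MSA stack, and the MSA stack attends to the pair representation as it was before the pair-stack update (z_ori). A schematic of the two orderings follows, with the sub-modules passed in as callables so the sketch stands alone; the tensor-parallel all-to-all calls are omitted, so this is a simplification of the class above, not a drop-in replacement.

```python
from typing import Callable

def evoformer_block_order(
    m, z, msa_mask, pair_mask,
    msa_stack: Callable, communication: Callable, pair_stack: Callable,
    is_multimer: bool,
):
    if not is_multimer:
        m = msa_stack(m, z, msa_mask)        # monomer: MSA track first
        z = z + communication(m, msa_mask)   # outer-product mean updates z
        z = pair_stack(z, pair_mask)
    else:
        z = z + communication(m, msa_mask)   # multimer: communication first
        z_ori = z                            # pair rep before the pair stack
        z = pair_stack(z, pair_mask)
        m = msa_stack(m, z_ori, msa_mask)    # MSA track uses the pre-update z
    return m, z
```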
@@ -260,7 +268,6 @@ class TemplatePairStackBlock(nn.Module):
            single_templates[i] = single

        z = torch.cat(single_templates, dim=-4)
-
        if self.last_block:
            z = gather(z, dim=1)
            z = z[:, :-padding_size, :-padding_size, :]
......
@@ -26,6 +26,7 @@ from fastfold.utils.feats import (
)
from fastfold.model.nn.embedders import (
    InputEmbedder,
    InputEmbedderMultimer,
    RecyclingEmbedder,
    TemplateAngleEmbedder,
    TemplatePairEmbedder,
@@ -69,6 +70,11 @@ class AlphaFold(nn.Module):
        extra_msa_config = config.extra_msa

        # Main trunk + structure module
        if self.globals.is_multimer:
            self.input_embedder = InputEmbedderMultimer(
                **config["input_embedder"],
            )
        else:
            self.input_embedder = InputEmbedder(
                **config["input_embedder"],
            )
......
@@ -15,7 +15,7 @@
import torch
import torch.nn as nn

-from typing import Tuple
+from typing import Tuple, Dict

from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
@@ -125,6 +125,146 @@ class InputEmbedder(nn.Module):
        return msa_emb, pair_emb


class InputEmbedderMultimer(nn.Module):
    """
    Embeds a subset of the input features.

    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        max_relative_idx: int,
        use_chain_relative: bool,
        max_relative_chain: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            relpos_k:
                Window size used in relative positional encoding
        """
        super(InputEmbedderMultimer, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # RPE stuff
        self.max_relative_idx = max_relative_idx
        self.use_chain_relative = use_chain_relative
        self.max_relative_chain = max_relative_chain
        if self.use_chain_relative:
            self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
        else:
            self.no_bins = 2 * max_relative_idx + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, batch: Dict[str, torch.Tensor]):
        pos = batch["residue_index"]
        asym_id = batch["asym_id"]
        asym_id_same = asym_id[..., None] == asym_id[..., None, :]
        offset = pos[..., None] - pos[..., None, :]

        clipped_offset = torch.clamp(
            offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
        )

        rel_feats = []
        if self.use_chain_relative:
            final_offset = torch.where(
                asym_id_same,
                clipped_offset,
                (2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
            )

            rel_pos = torch.nn.functional.one_hot(
                final_offset,
                2 * self.max_relative_idx + 2,
            )
            rel_feats.append(rel_pos)

            entity_id = batch["entity_id"]
            entity_id_same = entity_id[..., None] == entity_id[..., None, :]
            rel_feats.append(entity_id_same[..., None])

            sym_id = batch["sym_id"]
            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]

            max_rel_chain = self.max_relative_chain

            clipped_rel_chain = torch.clamp(
                rel_sym_id + max_rel_chain,
                0,
                2 * max_rel_chain,
            )

            final_rel_chain = torch.where(
                entity_id_same,
                clipped_rel_chain,
                (2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
            )

            rel_chain = torch.nn.functional.one_hot(
                final_rel_chain.long(),
                2 * max_rel_chain + 2,
            )
            rel_feats.append(rel_chain)
        else:
            rel_pos = torch.nn.functional.one_hot(
                clipped_offset,
                2 * self.max_relative_idx + 1,
            )
            rel_feats.append(rel_pos)

        rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)

        return self.linear_relpos(rel_feat)

    def forward(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tf = batch["target_feat"]
        msa = batch["msa_feat"]

        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]
        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
        pair_emb = pair_emb + self.relpos(batch)

        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
        tf_m = (
            self.linear_tf_m(tf)
            .unsqueeze(-3)
            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
        )
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb


class RecyclingEmbedder(nn.Module):
    """
......
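With the values supplied in multimer_model_config_update (max_relative_idx=32, use_chain_relative=True, max_relative_chain=2), the relative-position feature concatenated in relpos() above has 2*32 + 2 = 66 bins for the clipped residue offset (including the "different chain" bin), 1 channel for the same-entity flag, and 2*2 + 2 = 6 bins for the clipped chain offset, so no_bins = 73. A quick check of that arithmetic:

```python
max_relative_idx = 32
max_relative_chain = 2

rel_pos_bins = 2 * max_relative_idx + 2      # 66: one-hot offset incl. "other chain" bin
entity_flag = 1                              # entity_id_same appended as a single channel
rel_chain_bins = 2 * max_relative_chain + 2  # 6: one-hot chain offset incl. "other entity" bin

no_bins = rel_pos_bins + entity_flag + rel_chain_bins
print(no_bins)  # 73, matching 2*max_relative_idx + 2 + 1 + 2*max_relative_chain + 2
```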
@@ -142,6 +142,7 @@ class TemplatePairStackBlock(nn.Module):
        pair_transition_n: int,
        dropout_rate: float,
        inf: float,
        is_multimer: bool=False,
        **kwargs,
    ):
        super(TemplatePairStackBlock, self).__init__()
@@ -153,6 +154,7 @@ class TemplatePairStackBlock(nn.Module):
        self.pair_transition_n = pair_transition_n
        self.dropout_rate = dropout_rate
        self.inf = inf
        self.is_multimer = is_multimer

        self.dropout_row = DropoutRowwise(self.dropout_rate)
        self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -196,6 +198,7 @@ class TemplatePairStackBlock(nn.Module):
        single_templates_masks = [
            m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
        ]
        if not self.is_multimer:
            for i in range(len(single_templates)):
                single = single_templates[i]
                single_mask = single_templates_masks[i]
@@ -233,6 +236,29 @@ class TemplatePairStackBlock(nn.Module):
                )
                single_templates[i] = single
        else:
            for i in range(len(single_templates)):
                single = single_templates[i]
                single_mask = single_templates_masks[i]
                single = single + self.dropout_row(
                    self.tri_att_start(single, chunk_size=chunk_size, mask=single_mask)
                )
                single = single + self.dropout_col(
                    self.tri_att_end(single, chunk_size=chunk_size, mask=single_mask)
                )
                single = single + self.dropout_row(
                    self.tri_mul_out(single, mask=single_mask)
                )
                single = single + self.dropout_row(
                    self.tri_mul_in(single, mask=single_mask)
                )
                single = single + self.pair_transition(
                    single,
                    mask=single_mask if _mask_trans else None,
                    chunk_size=chunk_size,
                )
                single_templates[i] = single

        z = torch.cat(single_templates, dim=-4)
......
import os
import time
from multiprocessing import cpu_count

import ray
from ray import workflow

from fastfold.workflow.factory import JackHmmerFactory, HHSearchFactory, HHBlitsFactory
from fastfold.workflow import batch_run
@@ -80,6 +81,7 @@ class FastFoldDataWorkFlow:
            print("Workflow not found. Clean. Skipping")
            pass

        # prepare alignment directory for alignment outputs
        if alignment_dir is None:
            alignment_dir = os.path.join(output_dir, "alignment")
......
@@ -23,6 +23,7 @@ from datetime import date
import numpy as np
import torch
import torch.multiprocessing as mp
import pickle

from fastfold.model.hub import AlphaFold
import fastfold
@@ -111,6 +112,7 @@ def inference_model(rank, world_size, result_q, batch, args):

def main(args):
    config = model_config(args.model_name)
    global_is_multimer = True if args.model_preset == "multimer" else False

    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
@@ -147,6 +149,7 @@ def main(args):
    seqs, tags = parse_fasta(fasta)
    for tag, seq in zip(tags, seqs):
        print(f"tag:{tag} seq:{seq}")
        batch = [None]

        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
@@ -155,44 +158,48 @@ def main(args):
        print("Generating features...")
        local_alignment_dir = os.path.join(alignment_dir, tag)

-        if (args.use_precomputed_alignments is None):
-            if not os.path.exists(local_alignment_dir):
-                os.makedirs(local_alignment_dir)
-            if args.enable_workflow:
-                print("Running alignment with ray workflow...")
-                alignment_data_workflow_runner = FastFoldDataWorkFlow(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                t = time.perf_counter()
-                alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
-                print(f"Alignment data workflow time: {time.perf_counter() - t}")
-            else:
-                alignment_runner = data_pipeline.AlignmentRunner(
-                    jackhmmer_binary_path=args.jackhmmer_binary_path,
-                    hhblits_binary_path=args.hhblits_binary_path,
-                    hhsearch_binary_path=args.hhsearch_binary_path,
-                    uniref90_database_path=args.uniref90_database_path,
-                    mgnify_database_path=args.mgnify_database_path,
-                    bfd_database_path=args.bfd_database_path,
-                    uniclust30_database_path=args.uniclust30_database_path,
-                    pdb70_database_path=args.pdb70_database_path,
-                    use_small_bfd=use_small_bfd,
-                    no_cpus=args.cpus,
-                )
-                alignment_runner.run(fasta_path, local_alignment_dir)
-
-        feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
-                                                    alignment_dir=local_alignment_dir)
+        # if global_is_multimer:
+        #     print("Multimer")
+        # else:
+        #     if (args.use_precomputed_alignments is None):
+        #         if not os.path.exists(local_alignment_dir):
+        #             os.makedirs(local_alignment_dir)
+        #         if args.enable_workflow:
+        #             print("Running alignment with ray workflow...")
+        #             alignment_data_workflow_runner = FastFoldDataWorkFlow(
+        #                 jackhmmer_binary_path=args.jackhmmer_binary_path,
+        #                 hhblits_binary_path=args.hhblits_binary_path,
+        #                 hhsearch_binary_path=args.hhsearch_binary_path,
+        #                 uniref90_database_path=args.uniref90_database_path,
+        #                 mgnify_database_path=args.mgnify_database_path,
+        #                 bfd_database_path=args.bfd_database_path,
+        #                 uniclust30_database_path=args.uniclust30_database_path,
+        #                 pdb70_database_path=args.pdb70_database_path,
+        #                 use_small_bfd=use_small_bfd,
+        #                 no_cpus=args.cpus,
+        #             )
+        #             t = time.perf_counter()
+        #             alignment_data_workflow_runner.run(fasta_path, output_dir=output_dir_base, alignment_dir=local_alignment_dir)
+        #             print(f"Alignment data workflow time: {time.perf_counter() - t}")
+        #         else:
+        #             alignment_runner = data_pipeline.AlignmentRunner(
+        #                 jackhmmer_binary_path=args.jackhmmer_binary_path,
+        #                 hhblits_binary_path=args.hhblits_binary_path,
+        #                 hhsearch_binary_path=args.hhsearch_binary_path,
+        #                 uniref90_database_path=args.uniref90_database_path,
+        #                 mgnify_database_path=args.mgnify_database_path,
+        #                 bfd_database_path=args.bfd_database_path,
+        #                 uniclust30_database_path=args.uniclust30_database_path,
+        #                 pdb70_database_path=args.pdb70_database_path,
+        #                 use_small_bfd=use_small_bfd,
+        #                 no_cpus=args.cpus,
+        #             )
+        #             alignment_runner.run(fasta_path, local_alignment_dir)
+
+        # feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
+        #                                             alignment_dir=local_alignment_dir)
+        feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))

        # Remove temporary FASTA file
        os.remove(fasta_path)
@@ -289,6 +296,14 @@ if __name__ == "__main__":
                        default='full_dbs',
                        choices=('reduced_dbs', 'full_dbs'))
    parser.add_argument('--data_random_seed', type=str, default=None)
    parser.add_argument(
        "--model_preset",
        type=str,
        default="monomer",
        choices=["monomer", "multimer"],
        help="Choose preset model configuration - the monomer model, the monomer model with "
        "extra ensembling, monomer model with pTM head, or multimer model",
    )
    add_data_args(parser)
    args = parser.parse_args()
......
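The modified script above hard-codes loading a pre-computed feature dictionary from a local pickle. For context, here is a hedged sketch of how such a file could be produced once with the (now commented-out) data pipeline and reused on later runs. The output path is a placeholder, and data_processor, fasta_path, and local_alignment_dir refer to objects already built in this script, so the snippet is not standalone.

```python
import pickle

# One-off: run the full pipeline and cache its output (placeholder path).
feature_dict = data_processor.process_fasta(
    fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
with open("features.pkl", "wb") as f:
    pickle.dump(feature_dict, f)

# Subsequent runs: skip MSA/template search and load the cached features,
# as the modified script does above.
with open("features.pkl", "rb") as f:
    feature_dict = pickle.load(f)
```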