Commit 4bd1b4d5 authored by Gustaf Ahdritz

Work on multimer continues

parent 54164fe8
......@@ -89,20 +89,32 @@ def build_template_angle_feat(template_feats):
return template_angle_feat
def dgram_from_positions(
pos: torch.Tensor,
min_bin: float = 3.25,
max_bin: float = 50.75,
no_bins: int = 39,
inf: float = 1e8,
):
dgram = torch.sum(
(pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
return dgram
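As a quick, hedged illustration of what the new helper produces (toy coordinates chosen only for this example), each pair of positions is encoded as a one-hot vector over squared-distance bins:

import torch

pos = torch.tensor([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
dgram = dgram_from_positions(pos)   # defaults: 39 bins between 3.25 and 50.75 Angstroms
print(dgram.shape)                  # torch.Size([3, 3, 39])
print(int(dgram[0, 1].argmax()))    # index of the bin containing the 4 Angstrom pair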
def build_template_pair_feat(
batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8
):
template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"]
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)
to_concat = [dgram, template_mask_2d[..., None]]
......@@ -143,6 +155,10 @@ def build_template_pair_feat(
inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = rigid_vec * inv_distance_scalar[..., None]
if(not use_unit_vector):
unit_vector = unit_vector * 0.
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None])
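The new use_unit_vector flag matches AlphaFold's handling of the template unit-vector features: the monomer models were trained with them zeroed out, while the multimer models consume them, so the flag should be set accordingly. A sketch of a call site (bin values are the usual defaults; is_multimer is a stand-in flag, not a name from this file):

t_pair = build_template_pair_feat(
    batch,
    min_bin=3.25,
    max_bin=50.75,
    no_bins=39,
    use_unit_vector=is_multimer,  # zeroed for monomer, kept for multimer
)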
......@@ -159,7 +175,7 @@ def build_extra_msa_feat(batch):
batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1),
]
return torch.cat(msa_feat, dim=-1)
return msa_feat
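build_extra_msa_feat now hands back the list of per-feature tensors rather than a pre-concatenated tensor, leaving the concatenation (or any multimer-specific assembly) to the caller. A minimal sketch, where extra_msa_embedder is a hypothetical downstream module:

feats = build_extra_msa_feat(batch)           # [one-hot MSA, has_deletion, deletion_value]
extra_msa_feat = torch.cat(feats, dim=-1)     # equivalent to the old return value
a = extra_msa_embedder(extra_msa_feat)        # hypothetical caller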
def torsion_angles_to_frames(
......
......@@ -39,6 +39,13 @@ class ParamType(Enum):
LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
)
LinearWeightMultimer = partial(
lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else
w.reshape(w.shape[0], -1).transpose(-1, -2)
)
LinearBiasMultimer = partial(
lambda w: w.reshape(-1)
)
Other = partial(lambda w: w)
def __init__(self, fn):
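As a shape-level sanity check for the new multimer weight transforms, here is a toy example (the tensor sizes are made up for illustration):

import torch

w = torch.randn(256, 8, 16)                         # hypothetical (c_in, no_heads, c_hidden) multimer weight
w_pt = w.reshape(w.shape[0], -1).transpose(-1, -2)  # what LinearWeightMultimer does
print(w_pt.shape)                                   # torch.Size([128, 256]) -- (out, in), as nn.Linear stores it

b = torch.randn(8, 16).reshape(-1)                  # LinearBiasMultimer just flattens
print(b.shape)                                      # torch.Size([128])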
......@@ -122,28 +129,32 @@ def assign(translation_dict, orig_weights):
raise
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
def get_translation_dict(model, is_multimer=False):
#######################
# Some templates
#######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearWeightMultimer = lambda l: (
Param(l, param_type=ParamType.LinearWeightMultimer)
)
LinearBiasMultimer = lambda l: (
Param(l, param_type=ParamType.LinearBiasMultimer)
)
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
......@@ -236,10 +247,48 @@ def import_jax_weights_(model, npz_path, version="model_1"):
)
IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q),
"q_scalar_projection": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points.linear),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMultimer(
pp.linear,
),
}
IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_q.weight,
),
},
"k_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_k.weight,
),
},
"v_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_v.weight,
),
},
"q_point_projection": PointProjectionParams(
ipa.linear_q_points
),
"k_point_projection": PointProjectionParams(
ipa.linear_k_points
),
"v_point_projection": PointProjectionParams(
ipa.linear_v_points
),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
......@@ -301,8 +350,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: {
"invariant_point_attention": IPAParams(sm.ipa),
def FoldIterationParams(sm):
d = {
"invariant_point_attention":
IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2),
......@@ -311,20 +362,33 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"input_projection_1":
LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
"resblock1_1":
LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1":
LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles":
LinearParams(sm.angle_resnet.linear_out),
},
}
if(is_multimer):
d.pop("affine_update")
d["quat_rigid"] = {
"rigid": LinearParams(
sm.bb_update.linear
)
}
return d
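# Note: in multimer mode the backbone update comes from a quaternion-based
# "quat_rigid"/"rigid" linear rather than the monomer "affine_update" key,
# hence the pop-and-replace just above.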
############################
# translations dict overflow
############################
tps_blocks = model.template_pair_stack.blocks
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks]
)
......@@ -335,6 +399,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
if(not is_multimer):
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
......@@ -354,26 +419,28 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"template_embedding": {
"single_template_embedding": {
"embedding2d": LinearParams(
model.template_pair_embedder.linear
model.template_embedder.template_pair_embedder.linear
),
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm
model.template_embedder.template_pair_stack.layer_norm
),
},
"attention": AttentionParams(model.template_pointwise_att.mha),
"attention": AttentionParams(
model.template_embedder.template_pointwise_att.mha
),
},
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1
model.template_embedder.template_angle_embedder.linear_1
),
"template_projection": LinearParams(
model.template_angle_embedder.linear_2
model.template_embedder.template_angle_embedder.linear_2
),
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
......@@ -410,6 +477,123 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
else:
temp_embedder = model.template_embedder
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"~_relative_encoding": {
"position_activations": LinearParams(
model.input_embedder.linear_relpos
),
},
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParamsMultimer(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(
temp_embedder.linear_t
),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(
model,
is_multimer=("multimer" in version)
)
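# The translation dict is selected purely from the version string, e.g.
# version="model_1_multimer" routes through the multimer branch above.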
no_templ = [
"model_3",
......@@ -439,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
# print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}")
print(f"Incorrect: {incorrect}")
print(f"Missing: {missing}")
assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
......
......@@ -1352,8 +1352,8 @@ class Rigid:
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c1_rots[..., 2, 0] = -1 * sin_c2
c1_rots[..., 2, 2] = cos_c2
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz)
......
......@@ -26,7 +26,12 @@ import time
import torch
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.data import (
data_pipeline,
feature_pipeline,
templates,
)
from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein
......@@ -49,7 +54,36 @@ def main(args):
#script_preset_(model)
model = model.to(args.model_device)
template_featurizer = templates.TemplateHitFeaturizer(
is_multimer = "multimer" in args.model_name
if(is_multimer):
if(not args.use_precomputed_alignments):
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
else:
template_searcher = None
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
else:
if(not args.use_precomputed_alignments):
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
)
else:
template_searcher = None
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
......@@ -58,67 +92,96 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path
)
use_small_bfd=(args.bfd_database_path is None)
if(not args.use_precomputed_alignments):
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus,
)
else:
alignment_runner = None
data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if(is_multimer):
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor,
)
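# The multimer pipeline wraps the monomer one: monomer features are generated
# per chain and then paired/merged into a single multimer example.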
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
feature_processor = feature_pipeline.FeaturePipeline(
config.data
)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if(args.use_precomputed_alignments is None):
if(not args.use_precomputed_alignments):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
for fasta_path in os.listdir(args.fasta_dir):
if(not ".fasta" == os.path.splitext(fasta_path)[-1]):
print(f"Skipping {fasta_path}. Not a .fasta file...")
continue
fasta_path = os.path.join(args.fasta_dir, fasta_path)
# Gather input sequences
with open(args.fasta_path, "r") as fp:
lines = [l.strip() for l in fp.readlines()]
with open(fasta_path, "r") as fp:
data = fp.read()
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2]
# The leading '>' was already consumed by the split above; keep just the identifier.
tags = [t.split()[0] for t in tags]
for tag, seq in zip(tags, seqs):
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
if((not is_multimer) and len(tags) != 1):
print(
f"{fasta_path} contains more than one sequence but "
f"multimer mode is not enabled. Skipping..."
)
continue
logging.info("Generating features...")
for tag, seq in zip(tags, seqs):
tag, seq = tags[0], seqs[0]
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(
fasta_path, local_alignment_dir
)
if(is_multimer):
local_alignment_dir = alignment_dir
else:
local_alignment_dir = os.path.join(
alignment_dir,
tags[0],
)
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
feature_dict, mode='predict', is_multimer=is_multimer,
)
logging.info("Executing model...")
......@@ -130,6 +193,13 @@ def main(args):
}
t = time.perf_counter()
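# Try inference with chunking disabled first (fastest); if that fails, e.g.
# with an out-of-memory error, restore the configured chunk size and retry.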
chunk_size = model.globals.chunk_size
try:
model.globals.chunk_size = None
out = model(batch)
except RuntimeError as e:
model.globals.chunk_size = chunk_size
out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}")
......@@ -183,7 +253,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_path", type=str,
"fasta_dir", type=str,
)
parser.add_argument(
"template_mmcif_dir", type=str,
......
import copy
import os
import torch
import deepspeed
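# Minimal DeepSpeed inference smoke test: a deep stack of Linear layers is
# wrapped with deepspeed.init_inference (model-parallel across WORLD_SIZE ranks)
# and run once on dummy input.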
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.ml = torch.nn.ModuleList()
for _ in range(4000):
self.ml.append(torch.nn.Linear(500, 500))
def forward(self, batch):
for i, l in enumerate(self.ml):
# print(f"{i}: {l.weight.device}")
batch = l(batch)
return batch
class DummyDataset(torch.utils.data.Dataset):
def __init__(self):
self.batch = torch.rand(500, 500)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(f"cuda:{local_rank}")
model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
model,
mp_size=world_size,
checkpoint=None,
replace_method=None,
#replace_method="auto"
)
out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# print(out)
......@@ -62,4 +62,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded."
......@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"
TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"
SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"
mkdir --parents "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"
# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
File mode changed from 100644 to 100755