Commit 4bd1b4d5 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Work on multimer continues

parent 54164fe8
...@@ -89,20 +89,32 @@ def build_template_angle_feat(template_feats): ...@@ -89,20 +89,32 @@ def build_template_angle_feat(template_feats):
return template_angle_feat return template_angle_feat
def dgram_from_positions(
    pos: torch.Tensor,
    min_bin: float = 3.25,
    max_bin: float = 50.75,
    no_bins: int = 39,
    inf: float = 1e8,
):
    """Compute a one-hot distogram from 3D point coordinates.

    Args:
        pos: [..., N, 3] tensor of point positions.
        min_bin: Left edge of the first distance bin.
        max_bin: Left edge of the last distance bin.
        no_bins: Number of distance bins. Must be an int — torch.linspace
            requires an integer step count (the annotation previously said
            float, which was incorrect).
        inf: Large value used as the open upper edge of the final bin.

    Returns:
        [..., N, N, no_bins] one-hot encoding of pairwise distances. A pair
        whose distance is below min_bin falls into no bin (all zeros), since
        both comparisons below are strict.
    """
    # Pairwise squared distances, kept as [..., N, N, 1] so they broadcast
    # against the [no_bins]-shaped bin edges.
    dgram = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # Bin edges are squared so squared distances can be compared directly,
    # avoiding a sqrt.
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    return dgram
def build_template_pair_feat( def build_template_pair_feat(
batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8 batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8
): ):
template_mask = batch["template_pseudo_beta_mask"] template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
# Compute distogram (this seems to differ slightly from Alg. 5) # Compute distogram (this seems to differ slightly from Alg. 5)
tpb = batch["template_pseudo_beta"] tpb = batch["template_pseudo_beta"]
dgram = torch.sum( dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)
(tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
to_concat = [dgram, template_mask_2d[..., None]] to_concat = [dgram, template_mask_2d[..., None]]
...@@ -143,6 +155,10 @@ def build_template_pair_feat( ...@@ -143,6 +155,10 @@ def build_template_pair_feat(
inv_distance_scalar = inv_distance_scalar * template_mask_2d inv_distance_scalar = inv_distance_scalar * template_mask_2d
unit_vector = rigid_vec * inv_distance_scalar[..., None] unit_vector = rigid_vec * inv_distance_scalar[..., None]
if(not use_unit_vector):
unit_vector = unit_vector * 0.
to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
to_concat.append(template_mask_2d[..., None]) to_concat.append(template_mask_2d[..., None])
...@@ -159,7 +175,7 @@ def build_extra_msa_feat(batch): ...@@ -159,7 +175,7 @@ def build_extra_msa_feat(batch):
batch["extra_has_deletion"].unsqueeze(-1), batch["extra_has_deletion"].unsqueeze(-1),
batch["extra_deletion_value"].unsqueeze(-1), batch["extra_deletion_value"].unsqueeze(-1),
] ]
return torch.cat(msa_feat, dim=-1) return msa_feat
def torsion_angles_to_frames( def torsion_angles_to_frames(
......
...@@ -39,6 +39,13 @@ class ParamType(Enum): ...@@ -39,6 +39,13 @@ class ParamType(Enum):
LinearWeightOPM = partial( LinearWeightOPM = partial(
lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2) lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
) )
LinearWeightMultimer = partial(
lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else
w.reshape(w.shape[0], -1).transpose(-1, -2)
)
LinearBiasMultimer = partial(
lambda w: w.reshape(-1)
)
Other = partial(lambda w: w) Other = partial(lambda w: w)
def __init__(self, fn): def __init__(self, fn):
...@@ -122,28 +129,32 @@ def assign(translation_dict, orig_weights): ...@@ -122,28 +129,32 @@ def assign(translation_dict, orig_weights):
raise raise
def import_jax_weights_(model, npz_path, version="model_1"): def get_translation_dict(model, is_multimer=False):
data = np.load(npz_path)
####################### #######################
# Some templates # Some templates
####################### #######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight)) LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearBias = lambda l: (Param(l)) LinearBias = lambda l: (Param(l))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA)) LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA)) LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM)) LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearWeightMultimer = lambda l: (
Param(l, param_type=ParamType.LinearWeightMultimer)
)
LinearBiasMultimer = lambda l: (
Param(l, param_type=ParamType.LinearBiasMultimer)
)
LinearParams = lambda l: { LinearParams = lambda l: {
"weights": LinearWeight(l.weight), "weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias), "bias": LinearBias(l.bias),
} }
LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias),
}
LayerNormParams = lambda l: { LayerNormParams = lambda l: {
"scale": Param(l.weight), "scale": Param(l.weight),
"offset": Param(l.bias), "offset": Param(l.bias),
...@@ -236,10 +247,48 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -236,10 +247,48 @@ def import_jax_weights_(model, npz_path, version="model_1"):
) )
IPAParams = lambda ipa: { IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q), "q_scalar_projection": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv), "kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points), "q_point_local": LinearParams(ipa.linear_q_points.linear),
"kv_point_local": LinearParams(ipa.linear_kv_points), "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
"trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other
),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMultimer(
pp.linear,
),
}
IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_q.weight,
),
},
"k_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_k.weight,
),
},
"v_scalar_projection": {
"weights": LinearWeightMultimer(
ipa.linear_k.weight,
),
},
"q_point_projection": PointProjectionParams(
ipa.linear_q_points
),
"k_point_projection": PointProjectionParams(
ipa.linear_k_points
),
"v_point_projection": PointProjectionParams(
ipa.linear_v_points
),
"trainable_point_weights": Param( "trainable_point_weights": Param(
param=ipa.head_weights, param_type=ParamType.Other param=ipa.head_weights, param_type=ParamType.Other
), ),
...@@ -301,8 +350,10 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -301,8 +350,10 @@ def import_jax_weights_(model, npz_path, version="model_1"):
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True) ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: { def FoldIterationParams(sm):
"invariant_point_attention": IPAParams(sm.ipa), d = {
"invariant_point_attention":
IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa), "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.layers[0].linear_1), "transition": LinearParams(sm.transition.layers[0].linear_1),
"transition_1": LinearParams(sm.transition.layers[0].linear_2), "transition_1": LinearParams(sm.transition.layers[0].linear_2),
...@@ -311,20 +362,33 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -311,20 +362,33 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"affine_update": LinearParams(sm.bb_update.linear), "affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": { "rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in), "input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial), "input_projection_1":
LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1), "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2), "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1), "resblock1_1":
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2), LinearParams(sm.angle_resnet.layers[1].linear_1),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out), "resblock2_1":
LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles":
LinearParams(sm.angle_resnet.linear_out),
}, },
} }
if(is_multimer):
d.pop("affine_update")
d["quat_rigid"] = {
"rigid": LinearParams(
sm.bb_update.linear
)
}
return d
############################ ############################
# translations dict overflow # translations dict overflow
############################ ############################
tps_blocks = model.template_embedder.template_pair_stack.blocks
tps_blocks = model.template_pair_stack.blocks
tps_blocks_params = stacked( tps_blocks_params = stacked(
[TemplatePairBlockParams(b) for b in tps_blocks] [TemplatePairBlockParams(b) for b in tps_blocks]
) )
...@@ -335,6 +399,7 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -335,6 +399,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
evo_blocks = model.evoformer.blocks evo_blocks = model.evoformer.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks]) evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
if(not is_multimer):
translations = { translations = {
"evoformer": { "evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m), "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
...@@ -354,26 +419,28 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -354,26 +419,28 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"template_embedding": { "template_embedding": {
"single_template_embedding": { "single_template_embedding": {
"embedding2d": LinearParams( "embedding2d": LinearParams(
model.template_pair_embedder.linear model.template_embedder.template_pair_embedder.linear
), ),
"template_pair_stack": { "template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params, "__layer_stack_no_state": tps_blocks_params,
}, },
"output_layer_norm": LayerNormParams( "output_layer_norm": LayerNormParams(
model.template_pair_stack.layer_norm model.template_embedder.template_pair_stack.layer_norm
), ),
}, },
"attention": AttentionParams(model.template_pointwise_att.mha), "attention": AttentionParams(
model.template_embedder.template_pointwise_att.mha
),
}, },
"extra_msa_activations": LinearParams( "extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear model.extra_msa_embedder.linear
), ),
"extra_msa_stack": ems_blocks_params, "extra_msa_stack": ems_blocks_params,
"template_single_embedding": LinearParams( "template_single_embedding": LinearParams(
model.template_angle_embedder.linear_1 model.template_embedder.template_angle_embedder.linear_1
), ),
"template_projection": LinearParams( "template_projection": LinearParams(
model.template_angle_embedder.linear_2 model.template_embedder.template_angle_embedder.linear_2
), ),
"evoformer_iteration": evo_blocks_params, "evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear), "single_activations": LinearParams(model.evoformer.linear),
...@@ -410,6 +477,123 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -410,6 +477,123 @@ def import_jax_weights_(model, npz_path, version="model_1"):
"logits": LinearParams(model.aux_heads.masked_msa.linear), "logits": LinearParams(model.aux_heads.masked_msa.linear),
}, },
} }
else:
temp_embedder = model.template_embedder
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(
model.recycling_embedder.layer_norm_m
),
"prev_pair_norm": LayerNormParams(
model.recycling_embedder.layer_norm_z
),
"~_relative_encoding": {
"position_activations": LinearParams(
model.input_embedder.linear_relpos
),
},
"template_embedding": {
"single_template_embedding": {
"query_embedding_norm": LayerNormParams(
temp_embedder.template_pair_embedder.query_embedding_layer_norm
),
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_1
),
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParamsMultimer(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
temp_embedder.template_pair_embedder.query_embedding_linear
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(
temp_embedder.linear_t
),
},
"template_projection": LinearParams(
temp_embedder.template_single_embedder.template_projector,
),
"template_single_embedding": LinearParams(
temp_embedder.template_single_embedder.template_single_embedder,
),
"extra_msa_activations": LinearParams(
model.extra_msa_embedder.linear
),
"extra_msa_stack": ems_blocks_params,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(
model.structure_module.layer_norm_s
),
"initial_projection": LinearParams(
model.structure_module.linear_in
),
"pair_layer_norm": LayerNormParams(
model.structure_module.layer_norm_z
),
"fold_iteration": FoldIterationParams(model.structure_module),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(
model.aux_heads.plddt.layer_norm
),
"act_0": LinearParams(model.aux_heads.plddt.linear_1),
"act_1": LinearParams(model.aux_heads.plddt.linear_2),
"logits": LinearParams(model.aux_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.aux_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(
model.aux_heads.experimentally_resolved.linear
),
},
"masked_msa_head": {
"logits": LinearParams(model.aux_heads.masked_msa.linear),
},
}
return translations
def import_jax_weights_(model, npz_path, version="model_1"):
data = np.load(npz_path)
translations = get_translation_dict(
model,
is_multimer=("multimer" in version)
)
no_templ = [ no_templ = [
"model_3", "model_3",
...@@ -439,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -439,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys = list(flat.keys()) flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys] incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys] missing = [k for k in keys if k not in flat_keys]
# print(f"Incorrect: {incorrect}") print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}") print(f"Missing: {missing}")
assert len(incorrect) == 0 assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys()))) # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
......
...@@ -1352,8 +1352,8 @@ class Rigid: ...@@ -1352,8 +1352,8 @@ class Rigid:
c2_rots[..., 0, 0] = cos_c2 c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2 c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1 c2_rots[..., 1, 1] = 1
c1_rots[..., 2, 0] = -1 * sin_c2 c2_rots[..., 2, 0] = -1 * sin_c2
c1_rots[..., 2, 2] = cos_c2 c2_rots[..., 2, 2] = cos_c2
c_rots = rot_matmul(c2_rots, c1_rots) c_rots = rot_matmul(c2_rots, c1_rots)
n_xyz = rot_vec_mul(c_rots, n_xyz) n_xyz = rot_vec_mul(c_rots, n_xyz)
......
...@@ -26,7 +26,12 @@ import time ...@@ -26,7 +26,12 @@ import time
import torch import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline from openfold.data import (
data_pipeline,
feature_pipeline,
templates,
)
from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_ from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
...@@ -49,7 +54,36 @@ def main(args): ...@@ -49,7 +54,36 @@ def main(args):
#script_preset_(model) #script_preset_(model)
model = model.to(args.model_device) model = model.to(args.model_device)
template_featurizer = templates.TemplateHitFeaturizer( is_multimer = "multimer" in args.model_name
if(is_multimer):
if(not args.use_precomputed_alignments):
template_searcher = hmmsearch.Hmmsearch(
binary_path=args.hmmsearch_binary_path,
hmmbuild_binary_path=args.hmmbuild_binary_path,
database_path=args.pdb_seqres_database_path,
)
else:
template_searcher = None
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
else:
if(not args.use_precomputed_alignments):
template_searcher = hhsearch.HHSearch(
binary_path=args.hhsearch_binary_path,
databases=[args.pdb70_database_path],
)
else:
template_searcher = None
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates, max_hits=config.data.predict.max_templates,
...@@ -58,67 +92,96 @@ def main(args): ...@@ -58,67 +92,96 @@ def main(args):
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
) )
use_small_bfd=(args.bfd_database_path is None) if(not args.use_precomputed_alignments):
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniprot_database_path=args.uniprot_database_path,
template_searcher=template_searcher,
use_small_bfd=(args.bfd_database_path is None),
no_cpus=args.cpus,
)
else:
alignment_runner = None
data_processor = data_pipeline.DataPipeline( data_processor = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
if(is_multimer):
data_processor = data_pipeline.DataPipelineMultimer(
monomer_data_pipeline=data_processor,
)
output_dir_base = args.output_dir output_dir_base = args.output_dir
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(sys.maxsize) random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
feature_processor = feature_pipeline.FeaturePipeline(
config.data
)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
if(args.use_precomputed_alignments is None): if(not args.use_precomputed_alignments):
alignment_dir = os.path.join(output_dir_base, "alignments") alignment_dir = os.path.join(output_dir_base, "alignments")
else: else:
alignment_dir = args.use_precomputed_alignments alignment_dir = args.use_precomputed_alignments
for fasta_path in os.listdir(args.fasta_dir):
if(not ".fasta" == os.path.splitext(fasta_path)[-1]):
print(f"Skipping {fasta_path}. Not a .fasta file...")
continue
fasta_path = os.path.join(args.fasta_dir, fasta_path)
# Gather input sequences # Gather input sequences
with open(args.fasta_path, "r") as fp: with open(fasta_path, "r") as fp:
lines = [l.strip() for l in fp.readlines()] data = fp.read()
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
][1:]
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [l[1:] for l in tags]
for tag, seq in zip(tags, seqs): if((not is_multimer) and len(tags) != 1):
fasta_path = os.path.join(args.output_dir, "tmp.fasta") print(
with open(fasta_path, "w") as fp: f"{fasta_path} contains more than one sequence but "
fp.write(f">{tag}\n{seq}") f"multimer mode is not enabled. Skipping..."
)
continue
logging.info("Generating features...") for tag, seq in zip(tags, seqs):
tag, seq = tags[0], seqs[0]
local_alignment_dir = os.path.join(alignment_dir, tag) local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None): if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir): if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir) os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run( alignment_runner.run(
fasta_path, local_alignment_dir fasta_path, local_alignment_dir
) )
if(is_multimer):
local_alignment_dir = alignment_dir
else:
local_alignment_dir = os.path.join(
alignment_dir,
tags[0],
)
feature_dict = data_processor.process_fasta( feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir fasta_path=fasta_path, alignment_dir=local_alignment_dir
) )
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features( processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', feature_dict, mode='predict', is_multimer=is_multimer,
) )
logging.info("Executing model...") logging.info("Executing model...")
...@@ -130,6 +193,13 @@ def main(args): ...@@ -130,6 +193,13 @@ def main(args):
} }
t = time.perf_counter() t = time.perf_counter()
chunk_size = model.globals.chunk_size
try:
model.globals.chunk_size = None
out = model(batch)
except RuntimeError as e:
model.globals.chunk_size = chunk_size
out = model(batch) out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}") logging.info(f"Inference time: {time.perf_counter() - t}")
...@@ -183,7 +253,7 @@ def main(args): ...@@ -183,7 +253,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"fasta_path", type=str, "fasta_dir", type=str,
) )
parser.add_argument( parser.add_argument(
"template_mmcif_dir", type=str, "template_mmcif_dir", type=str,
......
import copy
import os
import torch
import deepspeed
# Rank of this process on the local node and total process count, as set by
# the distributed launcher; default to a single-process run when unset.
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
class Model(torch.nn.Module):
    """A deep stack of 4000 identical 500x500 linear layers.

    A synthetic workload for exercising sharded/parallel inference.
    """

    def __init__(self):
        super().__init__()
        # Build the whole stack in one shot instead of appending in a loop.
        self.ml = torch.nn.ModuleList(
            torch.nn.Linear(500, 500) for _ in range(4000)
        )

    def forward(self, batch):
        """Apply every linear layer in sequence and return the result."""
        out = batch
        for layer in self.ml:
            out = layer(out)
        return out
class DummyDataset(torch.utils.data.Dataset):
    """Dataset of 1000 copies of a single fixed random 500x500 tensor.

    The tensor is drawn once at construction; every index returns an
    independent deep copy of it.
    """

    def __init__(self):
        self.batch = torch.rand(500, 500)

    def __len__(self):
        return 1000

    def __getitem__(self, index):
        # Deep-copy so callers may mutate their item without affecting others.
        return copy.deepcopy(self.batch)
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
# Pull a single example and move it to this process's GPU.
example = next(iter(dl)).to(f"cuda:{local_rank}")
model = Model()
model = model.to(f"cuda:{local_rank}")
# Wrap the model for DeepSpeed inference across world_size processes.
# NOTE(review): replace_method=None presumably disables automatic kernel
# injection (the "auto" alternative is left commented out) -- confirm against
# the installed deepspeed version's init_inference signature.
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    checkpoint=None,
    replace_method=None,
    #replace_method="auto"
)
out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
#    print(out)
...@@ -62,4 +62,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}" ...@@ -62,4 +62,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..." echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}" bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
# Fix: previously re-ran download_uniclust30.sh here (copy-paste error),
# so the PDB SeqRes database was never fetched.
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"

echo "Downloading UniProt..."
# Fix: previously re-ran download_uniref90.sh here (copy-paste error),
# so the UniProt database was never fetched.
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded." echo "All data downloaded."
...@@ -31,7 +31,7 @@ fi ...@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR="$1" DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params" ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar" SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
BASENAME=$(basename "${SOURCE_URL}") BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}" mkdir --parents "${ROOT_DIR}"
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e

# Validate arguments before doing any work.
if (( $# == 0 )); then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

# aria2c performs the actual download; bail out early if it is missing.
if ! command -v aria2c > /dev/null 2>&1; then
    echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
    exit 1
fi

DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")

# Download the plain-text seqres database into its own subdirectory.
mkdir -p "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e

# Validate arguments before doing any work.
if (( $# == 0 )); then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

# aria2c performs the actual downloads; bail out early if it is missing.
if ! command -v aria2c > /dev/null 2>&1; then
    echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
    exit 1
fi

DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"

TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"

SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"

# Fetch and unpack both databases.
mkdir -p "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"

pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"

# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment