Commit 4bd1b4d5 authored by Gustaf Ahdritz

Work on multimer continues

parent 54164fe8
@@ -89,20 +89,32 @@ def build_template_angle_feat(template_feats):
    return template_angle_feat
+def dgram_from_positions(
+    pos: torch.Tensor,
+    min_bin: float = 3.25,
+    max_bin: float = 50.75,
+    no_bins: int = 39,
+    inf: float = 1e8,
+):
+    dgram = torch.sum(
+        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
+    )
+    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
+    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
+    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+
+    return dgram
def build_template_pair_feat(
-    batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8
+    batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8
):
    template_mask = batch["template_pseudo_beta_mask"]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    # Compute distogram (this seems to differ slightly from Alg. 5)
    tpb = batch["template_pseudo_beta"]
-    dgram = torch.sum(
-        (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
-    )
-    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
-    upper = torch.cat([lower[:-1], lower.new_tensor([inf])], dim=-1)
-    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
+    dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf)

    to_concat = [dgram, template_mask_2d[..., None]]
@@ -143,6 +155,10 @@ def build_template_pair_feat(
    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar[..., None]
+
+    if(not use_unit_vector):
+        unit_vector = unit_vector * 0.
+
    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
    to_concat.append(template_mask_2d[..., None])
@@ -159,7 +175,7 @@ def build_extra_msa_feat(batch):
        batch["extra_has_deletion"].unsqueeze(-1),
        batch["extra_deletion_value"].unsqueeze(-1),
    ]
-    return torch.cat(msa_feat, dim=-1)
+    return msa_feat
def torsion_angles_to_frames(
...
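For reference, a quick standalone check of the binning logic that the new dgram_from_positions factors out (toy input, not part of the commit): squared pairwise distances are compared against squared bin edges, yielding a one-hot encoding over no_bins distance bins.

    import torch

    pos = torch.tensor([[0., 0., 0.], [0., 0., 4.], [0., 0., 10.]])
    d2 = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    lower = torch.linspace(3.25, 50.75, 39) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
    dgram = ((d2 > lower) * (d2 < upper)).float()
    print(dgram.shape)           # torch.Size([3, 3, 39])
    print(dgram[0, 1].argmax())  # tensor(0): the 4 A pair lands in the first bin
    # Distances below min_bin (e.g. the diagonal) encode as the all-zero vector.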
@@ -39,6 +39,13 @@ class ParamType(Enum):
    LinearWeightOPM = partial(
        lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)
    )
+    LinearWeightMultimer = partial(
+        lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else
+        w.reshape(w.shape[0], -1).transpose(-1, -2)
+    )
+    LinearBiasMultimer = partial(
+        lambda w: w.reshape(-1)
+    )
    Other = partial(lambda w: w)

    def __init__(self, fn):
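The two new ParamType entries handle the weight layouts found in the multimer checkpoints, where a kernel may carry fused trailing head/channel dimensions. A toy illustration of the transformation (all shapes hypothetical, not read from a checkpoint):

    import torch

    w = torch.randn(384, 8, 16)   # e.g. a Haiku kernel [c_in, no_heads, c_hidden]
    pt = w.reshape(w.shape[0], -1).transpose(-1, -2)
    print(pt.shape)               # torch.Size([128, 384]): nn.Linear's [c_out, c_in]

    b = torch.randn(37)           # rank-1 kernels just gain a trailing unit dim
    print(b.unsqueeze(-1).shape)  # torch.Size([37, 1])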
@@ -122,28 +129,32 @@ def assign(translation_dict, orig_weights):
            raise

-def import_jax_weights_(model, npz_path, version="model_1"):
-    data = np.load(npz_path)
+def get_translation_dict(model, is_multimer=False):
    #######################
    # Some templates
    #######################
    LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
    LinearBias = lambda l: (Param(l))
    LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
    LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
    LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
+    LinearWeightMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearWeightMultimer)
+    )
+    LinearBiasMultimer = lambda l: (
+        Param(l, param_type=ParamType.LinearBiasMultimer)
+    )
    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
        "bias": LinearBias(l.bias),
    }
+    LinearParamsMultimer = lambda l: {
+        "weights": LinearWeightMultimer(l.weight),
+        "bias": LinearBiasMultimer(l.bias),
+    }
    LayerNormParams = lambda l: {
        "scale": Param(l.weight),
        "offset": Param(l.bias),
@@ -236,10 +247,48 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    )

    IPAParams = lambda ipa: {
-        "q_scalar": LinearParams(ipa.linear_q),
+        "q_scalar_projection": LinearParams(ipa.linear_q),
        "kv_scalar": LinearParams(ipa.linear_kv),
-        "q_point_local": LinearParams(ipa.linear_q_points),
-        "kv_point_local": LinearParams(ipa.linear_kv_points),
+        "q_point_local": LinearParams(ipa.linear_q_points.linear),
+        "kv_point_local": LinearParams(ipa.linear_kv_points.linear),
+        "trainable_point_weights": Param(
+            param=ipa.head_weights, param_type=ParamType.Other
+        ),
+        "attention_2d": LinearParams(ipa.linear_b),
+        "output_projection": LinearParams(ipa.linear_out),
+    }
+
+    PointProjectionParams = lambda pp: {
+        "point_projection": LinearParamsMultimer(
+            pp.linear,
+        ),
+    }
+
+    IPAParamsMultimer = lambda ipa: {
+        "q_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_q.weight,
+            ),
+        },
+        "k_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_k.weight,
+            ),
+        },
+        "v_scalar_projection": {
+            "weights": LinearWeightMultimer(
+                ipa.linear_v.weight,
+            ),
+        },
+        "q_point_projection": PointProjectionParams(
+            ipa.linear_q_points
+        ),
+        "k_point_projection": PointProjectionParams(
+            ipa.linear_k_points
+        ),
+        "v_point_projection": PointProjectionParams(
+            ipa.linear_v_points
+        ),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
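Note the structural difference this mapping reflects: the monomer IPA fuses keys and values into a single linear_kv, while the multimer module keeps separate, bias-free q/k/v projections (v_scalar_projection reads ipa.linear_v.weight; the copy shown on the original page repeated linear_k, an apparent typo). A minimal sketch of the two layouts, with hypothetical dimensions:

    import torch

    x = torch.randn(10, 384)
    kv = torch.nn.Linear(384, 2 * 192, bias=False)   # fused, monomer-style
    k_fused, v_fused = kv(x).split(192, dim=-1)

    k_proj = torch.nn.Linear(384, 192, bias=False)   # split, multimer-style
    v_proj = torch.nn.Linear(384, 192, bias=False)
    k_split, v_split = k_proj(x), v_proj(x)          # same shapes, separate weights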
@@ -301,30 +350,45 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
-    FoldIterationParams = lambda sm: {
-        "invariant_point_attention": IPAParams(sm.ipa),
-        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
-        "transition": LinearParams(sm.transition.layers[0].linear_1),
-        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
-        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
-        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
-        "affine_update": LinearParams(sm.bb_update.linear),
-        "rigid_sidechain": {
-            "input_projection": LinearParams(sm.angle_resnet.linear_in),
-            "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
-            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
-            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
-            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
-            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
-            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
-        },
-    }
+    def FoldIterationParams(sm):
+        d = {
+            "invariant_point_attention":
+                IPAParamsMultimer(sm.ipa) if is_multimer else IPAParams(sm.ipa),
+            "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
+            "transition": LinearParams(sm.transition.layers[0].linear_1),
+            "transition_1": LinearParams(sm.transition.layers[0].linear_2),
+            "transition_2": LinearParams(sm.transition.layers[0].linear_3),
+            "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
+            "affine_update": LinearParams(sm.bb_update.linear),
+            "rigid_sidechain": {
+                "input_projection": LinearParams(sm.angle_resnet.linear_in),
+                "input_projection_1":
+                    LinearParams(sm.angle_resnet.linear_initial),
+                "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
+                "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
+                "resblock1_1":
+                    LinearParams(sm.angle_resnet.layers[1].linear_1),
+                "resblock2_1":
+                    LinearParams(sm.angle_resnet.layers[1].linear_2),
+                "unnormalized_angles":
+                    LinearParams(sm.angle_resnet.linear_out),
+            },
+        }
+
+        if(is_multimer):
+            d.pop("affine_update")
+            d["quat_rigid"] = {
+                "rigid": LinearParams(
+                    sm.bb_update.linear
+                )
+            }
+
+        return d
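In the multimer structure module the affine_update head is thus replaced by a quat_rigid head. A hedged sketch of what a quaternion-parameterized rigid update looks like (the helper below is illustrative, assuming a 6-dimensional update vector as in the monomer backbone update; it is not the commit's implementation):

    import torch

    def quat_to_rot(q):
        # Unit quaternion (w, x, y, z) -> 3x3 rotation matrix.
        w, x, y, z = q.unbind(-1)
        return torch.stack([
            1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
            2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
            2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
        ], dim=-1).reshape(*q.shape[:-1], 3, 3)

    vec = torch.randn(6)                        # 3 rotation components + 3 translation
    quat = torch.cat([torch.ones(1), vec[:3]])  # fix the real part to 1, then normalize
    quat = quat / torch.linalg.norm(quat)
    R, t = quat_to_rot(quat), vec[3:]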
    ############################
    # translations dict overflow
    ############################

-    tps_blocks = model.template_pair_stack.blocks
+    tps_blocks = model.template_embedder.template_pair_stack.blocks
    tps_blocks_params = stacked(
        [TemplatePairBlockParams(b) for b in tps_blocks]
    )
@@ -335,82 +399,202 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
-    translations = {
-        "evoformer": {
-            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
-            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
-            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
-            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
-            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
-            "prev_msa_first_row_norm": LayerNormParams(
-                model.recycling_embedder.layer_norm_m
-            ),
-            "prev_pair_norm": LayerNormParams(
-                model.recycling_embedder.layer_norm_z
-            ),
-            "pair_activiations": LinearParams(
-                model.input_embedder.linear_relpos
-            ),
-            "template_embedding": {
-                "single_template_embedding": {
-                    "embedding2d": LinearParams(
-                        model.template_pair_embedder.linear
-                    ),
-                    "template_pair_stack": {
-                        "__layer_stack_no_state": tps_blocks_params,
-                    },
-                    "output_layer_norm": LayerNormParams(
-                        model.template_pair_stack.layer_norm
-                    ),
-                },
-                "attention": AttentionParams(model.template_pointwise_att.mha),
-            },
-            "extra_msa_activations": LinearParams(
-                model.extra_msa_embedder.linear
-            ),
-            "extra_msa_stack": ems_blocks_params,
-            "template_single_embedding": LinearParams(
-                model.template_angle_embedder.linear_1
-            ),
-            "template_projection": LinearParams(
-                model.template_angle_embedder.linear_2
-            ),
-            "evoformer_iteration": evo_blocks_params,
-            "single_activations": LinearParams(model.evoformer.linear),
-        },
-        "structure_module": {
-            "single_layer_norm": LayerNormParams(
-                model.structure_module.layer_norm_s
-            ),
-            "initial_projection": LinearParams(
-                model.structure_module.linear_in
-            ),
-            "pair_layer_norm": LayerNormParams(
-                model.structure_module.layer_norm_z
-            ),
-            "fold_iteration": FoldIterationParams(model.structure_module),
-        },
-        "predicted_lddt_head": {
-            "input_layer_norm": LayerNormParams(
-                model.aux_heads.plddt.layer_norm
-            ),
-            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
-            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
-            "logits": LinearParams(model.aux_heads.plddt.linear_3),
-        },
-        "distogram_head": {
-            "half_logits": LinearParams(model.aux_heads.distogram.linear),
-        },
-        "experimentally_resolved_head": {
-            "logits": LinearParams(
-                model.aux_heads.experimentally_resolved.linear
-            ),
-        },
-        "masked_msa_head": {
-            "logits": LinearParams(model.aux_heads.masked_msa.linear),
-        },
-    }
+    if(not is_multimer):
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "pair_activiations": LinearParams(
+                    model.input_embedder.linear_relpos
+                ),
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "embedding2d": LinearParams(
+                            model.template_embedder.template_pair_embedder.linear
+                        ),
+                        "template_pair_stack": {
+                            "__layer_stack_no_state": tps_blocks_params,
+                        },
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "attention": AttentionParams(
+                        model.template_embedder.template_pointwise_att.mha
+                    ),
+                },
+                "extra_msa_activations": LinearParams(
+                    model.extra_msa_embedder.linear
+                ),
+                "extra_msa_stack": ems_blocks_params,
+                "template_single_embedding": LinearParams(
+                    model.template_embedder.template_angle_embedder.linear_1
+                ),
+                "template_projection": LinearParams(
+                    model.template_embedder.template_angle_embedder.linear_2
+                ),
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(
+                    model.structure_module.linear_in
+                ),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(
+                    model.aux_heads.plddt.layer_norm
+                ),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(
+                    model.aux_heads.experimentally_resolved.linear
+                ),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
+    else:
+        temp_embedder = model.template_embedder
+
+        translations = {
+            "evoformer": {
+                "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
+                "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
+                "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
+                "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
+                "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
+                "prev_msa_first_row_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_m
+                ),
+                "prev_pair_norm": LayerNormParams(
+                    model.recycling_embedder.layer_norm_z
+                ),
+                "~_relative_encoding": {
+                    "position_activations": LinearParams(
+                        model.input_embedder.linear_relpos
+                    ),
+                },
+                "template_embedding": {
+                    "single_template_embedding": {
+                        "query_embedding_norm": LayerNormParams(
+                            temp_embedder.template_pair_embedder.query_embedding_layer_norm
+                        ),
+                        "template_pair_embedding_0": LinearParams(
+                            temp_embedder.template_pair_embedder.dgram_linear
+                        ),
+                        "template_pair_embedding_1": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
+                        ),
+                        "template_pair_embedding_2": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_1
+                        ),
+                        "template_pair_embedding_3": LinearParams(
+                            temp_embedder.template_pair_embedder.aatype_linear_2
+                        ),
+                        "template_pair_embedding_4": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.x_linear
+                        ),
+                        "template_pair_embedding_5": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.y_linear
+                        ),
+                        "template_pair_embedding_6": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.z_linear
+                        ),
+                        "template_pair_embedding_7": LinearParamsMultimer(
+                            temp_embedder.template_pair_embedder.backbone_mask_linear
+                        ),
+                        "template_pair_embedding_8": LinearParams(
+                            temp_embedder.template_pair_embedder.query_embedding_linear
+                        ),
+                        "template_embedding_iteration": tps_blocks_params,
+                        "output_layer_norm": LayerNormParams(
+                            model.template_embedder.template_pair_stack.layer_norm
+                        ),
+                    },
+                    "output_linear": LinearParams(
+                        temp_embedder.linear_t
+                    ),
+                },
+                "template_projection": LinearParams(
+                    temp_embedder.template_single_embedder.template_projector,
+                ),
+                "template_single_embedding": LinearParams(
+                    temp_embedder.template_single_embedder.template_single_embedder,
+                ),
+                "extra_msa_activations": LinearParams(
+                    model.extra_msa_embedder.linear
+                ),
+                "extra_msa_stack": ems_blocks_params,
+                "evoformer_iteration": evo_blocks_params,
+                "single_activations": LinearParams(model.evoformer.linear),
+            },
+            "structure_module": {
+                "single_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_s
+                ),
+                "initial_projection": LinearParams(
+                    model.structure_module.linear_in
+                ),
+                "pair_layer_norm": LayerNormParams(
+                    model.structure_module.layer_norm_z
+                ),
+                "fold_iteration": FoldIterationParams(model.structure_module),
+            },
+            "predicted_lddt_head": {
+                "input_layer_norm": LayerNormParams(
+                    model.aux_heads.plddt.layer_norm
+                ),
+                "act_0": LinearParams(model.aux_heads.plddt.linear_1),
+                "act_1": LinearParams(model.aux_heads.plddt.linear_2),
+                "logits": LinearParams(model.aux_heads.plddt.linear_3),
+            },
+            "distogram_head": {
+                "half_logits": LinearParams(model.aux_heads.distogram.linear),
+            },
+            "experimentally_resolved_head": {
+                "logits": LinearParams(
+                    model.aux_heads.experimentally_resolved.linear
+                ),
+            },
+            "masked_msa_head": {
+                "logits": LinearParams(model.aux_heads.masked_msa.linear),
+            },
+        }
+
+    return translations
+def import_jax_weights_(model, npz_path, version="model_1"):
+    data = np.load(npz_path)
+
+    translations = get_translation_dict(
+        model,
+        is_multimer=("multimer" in version)
+    )
    no_templ = [
        "model_3",
        "model_4",
@@ -439,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
    flat_keys = list(flat.keys())
    incorrect = [k for k in flat_keys if k not in keys]
    missing = [k for k in keys if k not in flat_keys]
-    # print(f"Incorrect: {incorrect}")
-    # print(f"Missing: {missing}")
+    print(f"Incorrect: {incorrect}")
+    print(f"Missing: {missing}")
    assert len(incorrect) == 0
    # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
...
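The Incorrect/Missing report above compares the flattened translation dict against the checkpoint's parameter names. A minimal sketch of the kind of flattening involved (the repo's actual helper may differ in details such as the separator):

    def flatten(d, prefix=""):
        # Collapse a nested translation dict into flat "a/b/c"-style keys.
        out = {}
        for k, v in d.items():
            name = f"{prefix}/{k}" if prefix else k
            if isinstance(v, dict):
                out.update(flatten(v, name))
            else:
                out[name] = v
        return out

    assert list(flatten({"a": {"b": 1}, "c": 2})) == ["a/b", "c"]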
@@ -1352,8 +1352,8 @@ class Rigid:
        c2_rots[..., 0, 0] = cos_c2
        c2_rots[..., 0, 2] = sin_c2
        c2_rots[..., 1, 1] = 1
-        c1_rots[..., 2, 0] = -1 * sin_c2
-        c1_rots[..., 2, 2] = cos_c2
+        c2_rots[..., 2, 0] = -1 * sin_c2
+        c2_rots[..., 2, 2] = cos_c2

        c_rots = rot_matmul(c2_rots, c1_rots)
        n_xyz = rot_vec_mul(c_rots, n_xyz)
...
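The two corrected lines were a genuine bug: the bottom row of the y-axis rotation was being written into c1_rots instead of c2_rots, leaving c2_rots an incomplete rotation matrix. For reference, the matrix the fixed code assembles (toy angle, not from the commit):

    import math
    import torch

    cos_c2, sin_c2 = math.cos(0.3), math.sin(0.3)
    c2_rots = torch.tensor([
        [cos_c2,  0., sin_c2],
        [0.,      1., 0.    ],
        [-sin_c2, 0., cos_c2],
    ])
    print(c2_rots @ c2_rots.transpose(-1, -2))  # ~identity: a proper y-axis rotation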
@@ -26,7 +26,12 @@ import time

import torch

from openfold.config import model_config
-from openfold.data import templates, feature_pipeline, data_pipeline
+from openfold.data import (
+    data_pipeline,
+    feature_pipeline,
+    templates,
+)
+from openfold.data.tools import hhsearch, hmmsearch
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants, protein
@@ -48,79 +53,137 @@ def main(args):
    import_jax_weights_(model, args.param_path, version=args.model_name)

    #script_preset_(model)

    model = model.to(args.model_device)
-    template_featurizer = templates.TemplateHitFeaturizer(
-        mmcif_dir=args.template_mmcif_dir,
-        max_template_date=args.max_template_date,
-        max_hits=config.data.predict.max_templates,
-        kalign_binary_path=args.kalign_binary_path,
-        release_dates_path=args.release_dates_path,
-        obsolete_pdbs_path=args.obsolete_pdbs_path
-    )
-
-    use_small_bfd=(args.bfd_database_path is None)
+    is_multimer = "multimer" in args.model_name
+
+    if(is_multimer):
+        if(not args.use_precomputed_alignments):
+            template_searcher = hmmsearch.Hmmsearch(
+                binary_path=args.hmmsearch_binary_path,
+                hmmbuild_binary_path=args.hmmbuild_binary_path,
+                database_path=args.pdb_seqres_database_path,
+            )
+        else:
+            template_searcher = None
+
+        template_featurizer = templates.HmmsearchHitFeaturizer(
+            mmcif_dir=args.template_mmcif_dir,
+            max_template_date=args.max_template_date,
+            max_hits=config.data.predict.max_templates,
+            kalign_binary_path=args.kalign_binary_path,
+            release_dates_path=args.release_dates_path,
+            obsolete_pdbs_path=args.obsolete_pdbs_path
+        )
+    else:
+        if(not args.use_precomputed_alignments):
+            template_searcher = hhsearch.HHSearch(
+                binary_path=args.hhsearch_binary_path,
+                databases=[args.pdb70_database_path],
+            )
+        else:
+            template_searcher = None
+
+        template_featurizer = templates.HhsearchHitFeaturizer(
+            mmcif_dir=args.template_mmcif_dir,
+            max_template_date=args.max_template_date,
+            max_hits=config.data.predict.max_templates,
+            kalign_binary_path=args.kalign_binary_path,
+            release_dates_path=args.release_dates_path,
+            obsolete_pdbs_path=args.obsolete_pdbs_path
+        )
+
+    if(not args.use_precomputed_alignments):
+        alignment_runner = data_pipeline.AlignmentRunner(
+            jackhmmer_binary_path=args.jackhmmer_binary_path,
+            hhblits_binary_path=args.hhblits_binary_path,
+            uniref90_database_path=args.uniref90_database_path,
+            mgnify_database_path=args.mgnify_database_path,
+            bfd_database_path=args.bfd_database_path,
+            uniclust30_database_path=args.uniclust30_database_path,
+            uniprot_database_path=args.uniprot_database_path,
+            template_searcher=template_searcher,
+            use_small_bfd=(args.bfd_database_path is None),
+            no_cpus=args.cpus,
+        )
+    else:
+        alignment_runner = None
    data_processor = data_pipeline.DataPipeline(
        template_featurizer=template_featurizer,
    )
+
+    if(is_multimer):
+        data_processor = data_pipeline.DataPipelineMultimer(
+            monomer_data_pipeline=data_processor,
+        )
    output_dir_base = args.output_dir
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
-    feature_processor = feature_pipeline.FeaturePipeline(config.data)
+
+    feature_processor = feature_pipeline.FeaturePipeline(
+        config.data
+    )
    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
-    if(args.use_precomputed_alignments is None):
+    if(not args.use_precomputed_alignments):
        alignment_dir = os.path.join(output_dir_base, "alignments")
    else:
        alignment_dir = args.use_precomputed_alignments
-    # Gather input sequences
-    with open(args.fasta_path, "r") as fp:
-        lines = [l.strip() for l in fp.readlines()]
-
-    tags, seqs = lines[::2], lines[1::2]
-    tags = [l[1:] for l in tags]
-
-    for tag, seq in zip(tags, seqs):
-        fasta_path = os.path.join(args.output_dir, "tmp.fasta")
-        with open(fasta_path, "w") as fp:
-            fp.write(f">{tag}\n{seq}")
-
-        logging.info("Generating features...")
-        local_alignment_dir = os.path.join(alignment_dir, tag)
-        if(args.use_precomputed_alignments is None):
-            if not os.path.exists(local_alignment_dir):
-                os.makedirs(local_alignment_dir)
-
-            alignment_runner = data_pipeline.AlignmentRunner(
-                jackhmmer_binary_path=args.jackhmmer_binary_path,
-                hhblits_binary_path=args.hhblits_binary_path,
-                hhsearch_binary_path=args.hhsearch_binary_path,
-                uniref90_database_path=args.uniref90_database_path,
-                mgnify_database_path=args.mgnify_database_path,
-                bfd_database_path=args.bfd_database_path,
-                uniclust30_database_path=args.uniclust30_database_path,
-                pdb70_database_path=args.pdb70_database_path,
-                use_small_bfd=use_small_bfd,
-                no_cpus=args.cpus,
-            )
-            alignment_runner.run(
-                fasta_path, local_alignment_dir
-            )
+    for fasta_path in os.listdir(args.fasta_dir):
+        if(not ".fasta" == os.path.splitext(fasta_path)[-1]):
+            print(f"Skipping {fasta_path}. Not a .fasta file...")
+            continue
+
+        fasta_path = os.path.join(args.fasta_dir, fasta_path)
+
+        # Gather input sequences
+        with open(fasta_path, "r") as fp:
+            data = fp.read()
+
+        lines = [
+            l.replace('\n', '')
+            for prot in data.split('>') for l in prot.strip().split('\n', 1)
+        ][1:]
+        tags, seqs = lines[::2], lines[1::2]
+
+        if((not is_multimer) and len(tags) != 1):
+            print(
+                f"{fasta_path} contains more than one sequence but "
+                f"multimer mode is not enabled. Skipping..."
+            )
+            continue
+
+        for tag, seq in zip(tags, seqs):
+            tag, seq = tags[0], seqs[0]
+            local_alignment_dir = os.path.join(alignment_dir, tag)
+            if(args.use_precomputed_alignments is None):
+                if not os.path.exists(local_alignment_dir):
+                    os.makedirs(local_alignment_dir)
+
+                alignment_runner.run(
+                    fasta_path, local_alignment_dir
+                )
+
+        if(is_multimer):
+            local_alignment_dir = alignment_dir
+        else:
+            local_alignment_dir = os.path.join(
+                alignment_dir,
+                tags[0],
+            )

        feature_dict = data_processor.process_fasta(
            fasta_path=fasta_path, alignment_dir=local_alignment_dir
        )
-
-        # Remove temporary FASTA file
-        os.remove(fasta_path)

        processed_feature_dict = feature_processor.process_features(
-            feature_dict, mode='predict',
+            feature_dict, mode='predict', is_multimer=is_multimer,
        )
        logging.info("Executing model...")
        batch = processed_feature_dict
        with torch.no_grad():
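The new list comprehension that replaces the old line-pair parsing also handles multi-line FASTA records: it splits on '>', separates each record's header from its body, and strips internal newlines from the sequence. A quick check on a toy FASTA string (hypothetical input, not from the commit):

    data = ">chain_A\nMKTAYIAK\nQRQISFVK\n>chain_B\nSHSQA\n"
    lines = [
        l.replace('\n', '')
        for prot in data.split('>') for l in prot.strip().split('\n', 1)
    ][1:]
    tags, seqs = lines[::2], lines[1::2]
    print(tags)  # ['chain_A', 'chain_B']
    print(seqs)  # ['MKTAYIAKQRQISFVK', 'SHSQA']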
@@ -130,9 +193,16 @@ def main(args):
            }

            t = time.perf_counter()
-            out = model(batch)
+            chunk_size = model.globals.chunk_size
+            try:
+                model.globals.chunk_size = None
+                out = model(batch)
+            except RuntimeError as e:
+                model.globals.chunk_size = chunk_size
+                out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}") logging.info(f"Inference time: {time.perf_counter() - t}")
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
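The try/except around the forward pass first attempts unchunked inference (chunk_size = None), which is faster, and falls back to the configured chunk size when the attempt raises a RuntimeError (typically CUDA out-of-memory). The general pattern in isolation, with hypothetical model/batch objects:

    def forward_with_fallback(model, batch, chunk_size):
        # Try the fast unchunked path first; retry chunked if memory runs out.
        try:
            model.globals.chunk_size = None
            return model(batch)
        except RuntimeError:
            model.globals.chunk_size = chunk_size
            return model(batch)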
@@ -143,7 +213,7 @@ def main(args):
        plddt_b_factors = np.repeat(
            plddt[..., None], residue_constants.atom_type_num, axis=-1
        )

        unrelaxed_protein = protein.from_prediction(
            features=batch,
            result=out,
@@ -183,7 +253,7 @@ def main(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
-        "fasta_path", type=str,
+        "fasta_dir", type=str,
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
...
import copy
import os
import torch
import deepspeed
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.ml = torch.nn.ModuleList()
for _ in range(4000):
self.ml.append(torch.nn.Linear(500, 500))
def forward(self, batch):
for i, l in enumerate(self.ml):
# print(f"{i}: {l.weight.device}")
batch = l(batch)
return batch
class DummyDataset(torch.utils.data.Dataset):
def __init__(self):
self.batch = torch.rand(500, 500)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(f"cuda:{local_rank}")
model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
model,
mp_size=world_size,
checkpoint=None,
replace_method=None,
#replace_method="auto"
)
out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# print(out)
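This standalone script appears to be a scratch test of DeepSpeed tensor-parallel inference: deepspeed.init_inference shards the 4000-layer toy model across the processes spawned by the DeepSpeed launcher, which supplies the LOCAL_RANK and WORLD_SIZE environment variables the script reads (e.g. by launching it with the deepspeed CLI and --num_gpus set to the number of available devices).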
@@ -62,4 +62,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"

echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded." echo "All data downloaded."
@@ -31,7 +31,7 @@ fi

DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params"
-SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar"
+SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar"
BASENAME=$(basename "${SOURCE_URL}")

mkdir --parents "${ROOT_DIR}"
...
File mode changed from 100644 to 100755 (seven script files)
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
File mode changed from 100644 to 100755 (two script files)
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"
TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"
SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"
mkdir --parents "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"
# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
File mode changed from 100644 to 100755