Commit fdcb72e8 authored by Christina Floristean

Bug fixes for multimer inference and monomer training

parent 51556d52
@@ -155,14 +155,17 @@ def model_config(
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.globals.is_multimer = True
c.globals.bfloat16 = True
c.globals.bfloat16 = False
c.globals.bfloat16_output = False
c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20
for k,v in multimer_model_config_update.items():
for k,v in multimer_model_config_update['model'].items():
c.model[k] = v
for k, v in multimer_model_config_update['loss'].items():
c.loss[k] = v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252
@@ -590,6 +593,12 @@ config = mlc.ConfigDict(
"c_out": 37,
},
},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1.
},
"relax": {
"max_iterations": 0, # no max
@@ -670,157 +679,154 @@ config = mlc.ConfigDict(
"eps": eps,
},
"ema": {"decay": 0.999},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance": -1
}
)
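To make the tolerance semantics concrete, here is a minimal standalone sketch (illustrative only, not part of this commit; `should_stop_recycling` and its arguments are hypothetical) of the early-stop test the comment describes, mirroring the CA-distance check in AlphaFold's recycling loop further down this diff:

    import torch

    def should_stop_recycling(prev_pos, next_pos, mask, tolerance, ca_idx=1):
        # A negative tolerance disables early stopping entirely.
        if tolerance < 0:
            return False
        def distances(p):  # [*, N, 3] -> [*, N, N] pairwise CA-CA distances
            return torch.sqrt(torch.sum((p[..., None, :] - p[..., None, :, :]) ** 2, dim=-1))
        # Mean squared change in pairwise distances between consecutive recycling steps
        sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
        pair_mask = mask[..., None] * mask[..., None, :]
        avg_sq_diff = torch.sum(sq_diff * pair_mask, dim=(-2, -1)) / (torch.sum(pair_mask, dim=(-2, -1)) + 1e-8)
        return bool((torch.sqrt(avg_sq_diff) < tolerance).all())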
multimer_model_config_update = {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"model": {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template_pair_stack": {
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
"c_z": c_z,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": True
},
"c_t": c_t,
"c_z": c_z,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": True
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
#"num_extra_msa": 2048
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
#"num_extra_msa": 2048
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"extra_msa_stack": {
"c_m": c_e,
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"clear_cache_between_blocks": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"no_bins": aux_distogram_bins,
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True,
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
"recycle_early_stop_tolerance": 0.5
},
"loss": {
"distogram": {
@@ -897,6 +903,5 @@ multimer_model_config_update = {
"enabled": True,
},
"eps": eps,
},
"recycle_early_stop_tolerance": 0.5
}
}
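The loop change at the top of this diff consumes that nesting: model-level keys go to `c.model` and loss-head keys to `c.loss`, instead of everything landing in `c.model` as the old flat update did. A minimal standalone sketch with toy values (hypothetical, not part of this commit):

    import ml_collections as mlc

    c = mlc.ConfigDict({
        "model": {"evoformer_stack": {"no_blocks": 48}},
        "loss": {"tm": {"weight": 0.0}},
    })
    update = {
        "model": {"recycle_early_stop_tolerance": 0.5},
        "loss": {"tm": {"weight": 0.1}},
    }
    for k, v in update["model"].items():
        c.model[k] = v  # overwrites or adds model-level keys only
    for k, v in update["loss"].items():
        c.loss[k] = v   # loss weights no longer leak into c.model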
@@ -151,7 +151,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain: i for i, chain in enumerate(self._chain_ids)
}
template_featurizer = templates.TemplateHitFeaturizer(
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=template_mmcif_dir,
max_template_date=max_template_date,
max_hits=max_template_hits,
@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein
@@ -34,22 +34,10 @@ FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def empty_template_feats(n_res) -> FeatureDict:
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions":
np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def make_template_features(
input_sequence: str,
hits: Sequence[Any],
template_featurizer: Any,
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
hits_cat = sum(hits.values(), [])
if(len(hits_cat) == 0 or template_featurizer is None):
@@ -61,11 +49,6 @@ def make_template_features(
)
template_features = templates_result.features
# The template featurizer doesn't format empty template features
# properly. This is a quick fix.
if(template_features["template_aatype"].shape[0] == 0):
template_features = empty_template_feats(len(input_sequence))
return template_features
@@ -453,7 +436,8 @@ class AlignmentRunner:
if(uniprot_database_path is not None):
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path
database_path=uniprot_database_path,
n_cpu=no_cpus
)
if(template_searcher is not None and
@@ -800,37 +784,6 @@ class DataPipeline:
return all_hits
def _parse_template_hits(
self,
alignment_dir: str,
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if (alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if (ext == ".hhr"):
hits = parsers.parse_hhr(read_template(start, size))
all_hits[name] = hits
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if (ext == ".hhr"):
with open(path, "r") as fp:
hits = parsers.parse_hhr(fp.read())
all_hits[f] = hits
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
@@ -935,15 +888,15 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(
hits = self._parse_template_hit_files(
alignment_dir,
input_sequence,
alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
self.template_featurizer
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
@@ -984,8 +937,9 @@ class DataPipeline:
is_distillation=is_distillation
)
hits = self._parse_template_hits(
hits = self._parse_template_hit_files(
alignment_dir,
input_sequence,
alignment_index
)
@@ -1016,8 +970,9 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(
hits = self._parse_template_hit_files(
alignment_dir,
input_sequence,
alignment_index
)
@@ -1107,7 +1062,7 @@ class DataPipeline:
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, alignment_index=None)
hits = self._parse_template_hit_files(alignment_dir, seq, alignment_index=None)
template_features = make_template_features(
seq,
hits,
@@ -89,18 +89,17 @@ def make_all_atom_aatype(protein):
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein["template_aatype"].shape[0]
if(num_templates > 0):
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
).expand(num_templates, -1)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
).expand(num_templates, -1)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
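A toy run of the remapping above (illustrative; one template, three residues, one-hot input over the 22 HHblits classes):

    import torch
    import torch.nn.functional as F

    one_hot = F.one_hot(torch.tensor([[0, 4, 21]]), num_classes=22).float()
    protein = {"template_aatype": one_hot}
    protein = fix_templates_aatype(protein)
    # template_aatype now holds integer indices in OpenFold's residue ordering,
    # translated from the HHblits ordering via MAP_HHBLITS_AATYPE_TO_OUR_AATYPE.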
@@ -2,7 +2,9 @@ from typing import Sequence
import torch
from openfold.config import NUM_RES
from openfold.data.data_transforms import curry1
from openfold.np import residue_constants as rc
from openfold.utils.tensor_utils import masked_mean
@@ -301,3 +303,177 @@ def make_msa_profile(batch):
)
return batch
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
mask = (diff_chain_mask[..., None] * pair_mask).bool()
# torch.where needs a boolean condition; Tensor.min(dim) returns (values, indices)
min_dist_per_res, _ = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
return interface_residues_idxs
def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
positions = protein["all_atom_positions"]
atom_mask = protein["all_atom_mask"]
asym_id = protein["asym_id"]
interface_residues = get_interface_residues(positions=positions,
atom_mask=atom_mask,
asym_id=asym_id,
interface_threshold=interface_threshold)
if interface_residues.numel() == 0:  # no inter-chain contacts; fall back to a contiguous crop
return get_contiguous_crop_idx(protein, crop_size, generator)
target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,),
device=positions.device, generator=generator)[0])]
ca_idx = rc.atom_order["CA"]
ca_positions = positions[..., ca_idx, :]
ca_mask = atom_mask[..., ca_idx].bool()
coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
to_target_distances = ca_pairwise_dists[target_res]
break_tie = (
torch.arange(
0, to_target_distances.shape[-1], device=positions.device
).float()
* 1e-3
)
to_target_distances = torch.where(ca_mask, to_target_distances, torch.inf) + break_tie
ret = torch.argsort(to_target_distances)[:crop_size]
return ret.sort().values
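A toy invocation of the spatial crop above (illustrative shapes: two 6-residue chains, every atom resolved):

    g = torch.Generator().manual_seed(0)
    protein = {
        "all_atom_positions": torch.randn(12, 37, 3),
        "all_atom_mask": torch.ones(12, 37),
        "asym_id": torch.tensor([0] * 6 + [1] * 6),
        "aatype": torch.zeros(12, dtype=torch.long),  # only consulted by the contiguous fallback
    }
    # Up to 8 sorted residue indices nearest (in CA space) to a random interface residue
    crop_idxs = get_spatial_crop_idx(protein, crop_size=8, interface_threshold=10.0, generator=g)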
def randint(lower, upper, generator, device):
return int(torch.randint(
lower,
upper + 1,
(1,),
device=device,
generator=generator,
)[0])
def get_contiguous_crop_idx(protein, crop_size, generator):
num_res = protein["aatype"].shape[0]
if num_res <= crop_size:
return torch.arange(num_res)
_, chain_lens = protein["asym_id"].unique(return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
num_remaining = int(chain_lens.sum())
num_budget = crop_size
crop_idxs = []
asym_offset = torch.tensor(0, dtype=torch.int64)
for j, idx in enumerate(shuffle_idx):
this_len = int(chain_lens[idx])
num_remaining -= this_len
# most residues we can keep from this entity
crop_size_max = min(num_budget, this_len)
# fewest residues we must keep from this entity
crop_size_min = min(this_len, max(0, num_budget - num_remaining))
# randint is inclusive of `upper`, so no extra +1 at the call sites
chain_crop_size = randint(lower=crop_size_min,
                          upper=crop_size_max,
                          generator=generator,
                          device=chain_lens.device)
chain_start = randint(lower=0,
                      upper=this_len - chain_crop_size,
                      generator=generator,
                      device=chain_lens.device)
crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
)
asym_offset += this_len
num_budget -= chain_crop_size
return torch.concat(crop_idxs)
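A worked example of the budget arithmetic above (hypothetical chain lengths):

    # chain_lens = [5, 7], crop_size = 8, shuffle order visits chain 1 first
    # chain 1 (len 7): num_remaining = 5; crop_size_max = min(8, 7) = 7;
    #                  crop_size_min = min(7, max(0, 8 - 5)) = 3  -> keep 3 to 7 residues
    # chain 0 (len 5): say chain 1 kept 4, so num_budget = 4; num_remaining = 0;
    #                  crop_size_min = min(5, max(0, 4 - 0)) = 4 = crop_size_max -> keep exactly 4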
@curry1
def random_crop_to_size(
protein,
crop_size,
max_templates,
shape_schema,
spatial_crop_prob,
interface_threshold,
subsample_templates=False,
seed=None,
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
use_spatial_crop = torch.rand((1,),
device=protein["seq_length"].device,
generator=g) < spatial_crop_prob
if use_spatial_crop:
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
else:
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = 0
# No need to subsample templates if there aren't any
subsample_templates = subsample_templates and num_templates
if subsample_templates:
templates_crop_start = randint(lower=0,
upper=num_templates + 1,
generator=g,
device=protein["seq_length"].device)
templates_select_indices = torch.randperm(
num_templates, device=protein["seq_length"].device, generator=g
)
else:
templates_crop_start = 0
num_res_crop_size = min(int(protein["seq_length"]), crop_size)
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
)
for k, v in protein.items():
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
v = v[slice(crop_start, crop_start + crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
protein[k] = v
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
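Because of `@curry1`, the transform is configured first and applied to the feature dict second. A hypothetical invocation (`shape_schema` maps feature names to per-dimension size labels such as NUM_RES):

    crop_fn = random_crop_to_size(
        crop_size=384,
        max_templates=4,
        shape_schema=shape_schema,
        spatial_crop_prob=0.5,
        interface_threshold=10.0,
        subsample_templates=True,
        seed=42,
    )
    protein = crop_fn(protein)  # a fixed seed yields the same crop for every ensemble iteration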
@@ -104,7 +104,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed else None,
)
)
@@ -89,6 +89,24 @@ TEMPLATE_FEATURES = {
}
def empty_template_feats(n_res):
return {
"template_aatype": np.zeros(
(0, n_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_mask": np.zeros(
(0, n_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32),
}
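For reference, the zero-template features above have these shapes; the aatype is one-hot over the 22 `restypes_with_x_and_gap` classes and `atom_type_num` is 37. (Note `np.object` requires NumPy < 1.24; newer NumPy would need plain `object`.)

    feats = empty_template_feats(128)
    assert feats["template_aatype"].shape == (0, 128, 22)
    assert feats["template_all_atom_mask"].shape == (0, 128, 37)
    assert feats["template_all_atom_positions"].shape == (0, 128, 37, 3)
    assert feats["template_sum_probs"].shape == (0, 1)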
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
@@ -1163,21 +1181,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
"template_aatype": np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_masks": np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.array([0], dtype=np.float32),
}
template_features = empty_template_feats(num_res)
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
@@ -1276,21 +1280,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
else:
num_res = len(query_sequence)
# Construct a default template with all zeros.
template_features = {
"template_aatype": np.zeros(
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
np.float32
),
"template_all_atom_masks": np.zeros(
(1, num_res, residue_constants.atom_type_num), np.float32
),
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_sum_probs": np.array([0], dtype=np.float32),
}
template_features = empty_template_feats(num_res)
return TemplateSearchResult(
features=template_features,
@@ -242,7 +242,7 @@ class InputEmbedderMultimer(nn.Module):
entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None])
rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype))
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
@@ -577,7 +577,7 @@ class TemplateEmbedder(nn.Module):
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
(n_templ, n, n, self.config.template_pair_embedder.c_out)
)
for i in range(n_templ):
@@ -667,17 +667,17 @@ class TemplatePairEmbedderMultimer(nn.Module):
):
super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.dgram_linear = Linear(c_dgram, c_out, init='relu')
self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.query_embedding_linear = Linear(c_z, c_out, init='relu')
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
self.x_linear = Linear(1, c_out, init='relu')
self.y_linear = Linear(1, c_out, init='relu')
self.z_linear = Linear(1, c_out, init='relu')
self.backbone_mask_linear = Linear(1, c_out, init='relu')
def forward(self,
template_dgram: torch.Tensor,
@@ -812,10 +812,10 @@ class TemplateEmbedderMultimer(nn.Module):
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_positions, pseudo_beta_mask = pseudo_beta_fn(
single_template_feats["template_aatype"],
single_template_feats["template_all_atom_positions"],
single_template_feats["template_all_atom_mask"])
template_dgram = dgram_from_positions(
template_positions,
@@ -186,11 +186,6 @@ class AlphaFold(nn.Module):
if self.config.recycle_early_stop_tolerance < 0:
return False
if no_batch_dims == 0:
prev_pos = prev_pos.unsqueeze(dim=0)
next_pos = next_pos.unsqueeze(dim=0)
mask = mask.unsqueeze(dim=0)
ca_idx = residue_constants.atom_order['CA']
sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
mask = mask[..., None] * mask[..., None, :]
@@ -265,7 +260,7 @@
requires_grad=False,
)
x_prev = pseudo_beta_fn(
pseudo_beta_x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
@@ -279,10 +274,12 @@
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
pseudo_beta_x_prev,
inplace_safe=inplace_safe,
)
del pseudo_beta_x_prev
if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
@@ -166,12 +166,14 @@ class PointProjection(nn.Module):
c_hidden: int,
num_points: int,
no_heads: int,
is_multimer: bool,
return_local_points: bool = False,
):
super().__init__()
self.return_local_points = return_local_points
self.no_heads = no_heads
self.num_points = num_points
self.is_multimer = is_multimer
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
@@ -181,24 +183,19 @@
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
if isinstance(rigids, Rigid3Array):
points_local = points_local.reshape(
*points_local.shape[:-1],
self.no_heads,
-1,
if self.is_multimer:
points_local = points_local.view(
points_local.shape[:-1] + (self.no_heads, -1)
)
points_local = torch.split(
points_local, points_local.shape[-1] // 3, dim=-1
)
points_local = torch.stack(points_local, dim=-1)
points_local = torch.stack(points_local, dim=-1).view(out_shape)
if not isinstance(rigids, Rigid3Array):
points_local = points_local.reshape(
*points_local.shape[:-2], self.no_heads, -1, 3
)
points_global = rigids[..., None, None].apply(points_local)
if(self.return_local_points):
@@ -260,7 +257,8 @@ class InvariantPointAttention(nn.Module):
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads
self.no_heads,
self.is_multimer
)
if(is_multimer):
@@ -270,12 +268,14 @@
self.c_s,
self.no_qk_points,
self.no_heads,
self.is_multimer
)
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points,
self.no_heads,
self.is_multimer
)
else:
self.linear_kv = Linear(self.c_s, 2 * hc)
@@ -283,6 +283,7 @@
self.c_s,
self.no_qk_points + self.no_v_points,
self.no_heads,
self.is_multimer
)
self.linear_b = Linear(self.c_z, self.no_heads)
@@ -504,6 +505,230 @@
return s
#TODO: This module follows the refactoring done in IPA for multimer. Running the regular IPA above
# in multimer mode should be equivalent, but tests do not pass unless using this version. Determine
# whether or not the increase in test error matters in practice.
class InvariantPointAttentionMultimer(nn.Module):
"""
Implements Algorithm 22.
"""
def __init__(
self,
c_s: int,
c_z: int,
c_hidden: int,
no_heads: int,
no_qk_points: int,
no_v_points: int,
inf: float = 1e5,
eps: float = 1e-8,
is_multimer: bool = True,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_hidden:
Hidden channel dimension
no_heads:
Number of attention heads
no_qk_points:
Number of query/key points to generate
no_v_points:
Number of value points to generate
"""
super(InvariantPointAttentionMultimer, self).__init__()
self.c_s = c_s
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.no_qk_points = no_qk_points
self.no_v_points = no_v_points
self.inf = inf
self.eps = eps
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc, bias=False)
self.linear_q_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
is_multimer=True
)
self.linear_k = Linear(self.c_s, hc, bias=False)
self.linear_v = Linear(self.c_s, hc, bias=False)
self.linear_k_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
is_multimer=True
)
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points,
self.no_heads,
is_multimer=True
)
self.linear_b = Linear(self.c_z, self.no_heads)
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
ipa_point_weights_init_(self.head_weights)
concat_out_dim = self.no_heads * (
self.c_z + self.c_hidden + self.no_v_points * 4
)
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
self.softmax = nn.Softmax(dim=-2)
def forward(
self,
s: torch.Tensor,
z: Optional[torch.Tensor],
r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
inplace_safe: bool = False,
_offload_inference: bool = False,
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
if(_offload_inference and inplace_safe):
z = _z_reference_list
else:
z = [z]
a = 0.
point_variance = (max(self.no_qk_points, 1) * 9.0 / 2)
point_weights = math.sqrt(1.0 / point_variance)
softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x))
head_weights = softplus(self.head_weights)
point_weights = point_weights * head_weights
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H, P_qk]
q_pts = Vec3Array.from_array(self.linear_q_points(s, r))
# [*, N_res, H, P_qk, 3]
k_pts = Vec3Array.from_array(self.linear_k_points(s, r))
pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.)
pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
a = a + pt_att
scalar_variance = max(self.c_hidden, 1) * 1.
scalar_weights = math.sqrt(1.0 / scalar_variance)
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
k = self.linear_k(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
q = q * scalar_weights
a = a + torch.einsum('...qhc,...khc->...qkh', q, k)
##########################
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z[0])
if (_offload_inference):
assert (sys.getrefcount(z[0]) == 2)
z[0] = z[0].cpu()
a = a + b
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
a = a + square_mask.unsqueeze(-1)
a = a * math.sqrt(1. / 3) # Normalize by number of logit terms (3)
a = self.softmax(a)
# [*, N_res, H * C_hidden]
v = self.linear_v(s)
# [*, N_res, H, C_hidden]
v = v.view(v.shape[:-1] + (self.no_heads, -1))
o = torch.einsum('...qkh, ...khc->...qhc', a, v)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, N_res, H, P_v, 3]
v_pts = Vec3Array.from_array(self.linear_v_points(s, r))
# [*, N_res, H, P_v]
o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# o_pt = Vec3Array(
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3),
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3),
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3),
# )
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H, P_v]
o_pt = r[..., None].apply_inverse_to_point(o_pt)
o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(epsilon=1e-8)
if (_offload_inference):
z[0] = z[0].to(o_pt.device)
o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
(o, *o_pt_flat, o_pt_norm, o_pair), dim=-1
).to(dtype=z[0].dtype)
)
return s
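For orientation, with the multimer structure-module hyperparameters used in this diff (no_qk_points=4, c_hidden=16), the logit scalings above evaluate to:

    # point term:  point_variance = 4 * 9 / 2 = 18, point_weights = sqrt(1/18) ~= 0.236,
    #              then scaled per head by softplus(head_weights)
    # scalar term: scalar_weights = sqrt(1/16) = 0.25
    # final:       a * sqrt(1/3) averages the three logit contributions
    #              (scalar attention, point attention, pair bias b)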
class BackboneUpdate(nn.Module):
"""
Implements part of Algorithm 23.
@@ -670,7 +895,8 @@ class StructureModule(nn.Module):
self.linear_in = Linear(self.c_s, self.c_s)
self.ipa = InvariantPointAttention(
ipa = InvariantPointAttention if not self.is_multimer else InvariantPointAttentionMultimer
self.ipa = ipa(
self.c_s,
self.c_z,
self.c_ipa,
@@ -521,12 +521,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
def compute_projection(pair, mask):
p = compute_projection_helper(pair, mask)
if self._outgoing:
left = p[..., :self.c_hidden]
right = p[..., self.c_hidden:]
else:
left = p[..., self.c_hidden:]
right = p[..., :self.c_hidden]
left = p[..., :self.c_hidden]
right = p[..., self.c_hidden:]
return left, right
@@ -580,12 +576,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
ab = ab * self.sigmoid(self.linear_ab_g(z))
ab = ab * self.linear_ab_p(z)
if self._outgoing:
a = ab[..., :self.c_hidden]
b = ab[..., self.c_hidden:]
else:
b = ab[..., :self.c_hidden]
a = ab[..., self.c_hidden:]
a = ab[..., :self.c_hidden]
b = ab[..., self.c_hidden:]
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
@@ -36,9 +36,9 @@ def get_rc_tensor(rc_np, aatype):
def atom14_to_atom37(
atom14_data: torch.Tensor, # (*, N, 14, ...)
aatype: torch.Tensor # (*, N)
) -> torch.Tensor: # (*, N, 37, ...)
) -> Tuple: # (*, N, 37, ...)
"""Convert atom14 to atom37 representation."""
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long()
no_batch_dims = len(aatype.shape) - 1
atom37_data = tensor_utils.batched_gather(
atom14_data,
@@ -50,10 +50,10 @@ def atom14_to_atom37(
if len(atom14_data.shape) == no_batch_dims + 2:
atom37_data *= atom37_mask
elif len(atom14_data.shape) == no_batch_dims + 3:
atom37_data *= atom37_mask[..., None].astype(atom37_data.dtype)
atom37_data *= atom37_mask[..., None].to(dtype=atom37_data.dtype)
else:
raise ValueError("Incorrectly shaped data")
return atom37_data
return atom37_data, atom37_mask
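Call sites must now unpack the mask as well; a hypothetical call with `atom14` of shape [*, N, 14, 3] and `aatype` of shape [*, N]:

    atom37_pos, atom37_mask = atom14_to_atom37(atom14, aatype)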
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
@@ -230,13 +230,13 @@ def torsion_angles_to_frames(
num_residues = aatype.shape[-1]
sin_angles = torch.cat(
[
torch.zeros_like(aatype).unsqueeze(),
torch.zeros_like(aatype).unsqueeze(dim=-1),
sin_angles,
],
dim=-1)
cos_angles = torch.cat(
[
torch.ones_like(aatype).unsqueeze(),
torch.ones_like(aatype).unsqueeze(dim=-1),
cos_angles
],
dim=-1
@@ -20,7 +20,7 @@ class QuatRigid(nn.Module):
def forward(self, activations: torch.Tensor) -> Rigid3Array:
# NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations.to(torch.float32))
rigid_flat = self.linear(activations)
rigid_flat = torch.unbind(rigid_flat, dim=-1)
if(self.full_quat):
@@ -172,20 +172,20 @@ class Rot3Array:
) -> Rot3Array:
"""Construct Rot3Array from components of quaternion."""
if normalize:
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
w = w * inv_norm
x = x * inv_norm
y = y * inv_norm
z = z * inv_norm
xx = 1 - 2 * (y ** 2 + z ** 2)
xy = 2 * (x * y - w * z)
xz = 2 * (x * z + w * y)
yx = 2 * (x * y + w * z)
yy = 1 - 2 * (x ** 2 + z ** 2)
yz = 2 * (y * z - w * x)
zx = 2 * (x * z - w * y)
zy = 2 * (y * z + w * x)
zz = 1 - 2 * (x ** 2 + y ** 2)
xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
xy = 2.0 * (x * y - w * z)
xz = 2.0 * (x * z + w * y)
yx = 2.0 * (x * y + w * z)
yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
yz = 2.0 * (y * z - w * x)
zx = 2.0 * (x * z - w * y)
zy = 2.0 * (y * z + w * x)
zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
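A quick sanity check of the formula above: the unit quaternion (w, x, y, z) = (0.5, 0.5, 0.5, 0.5) is the 120-degree rotation about (1, 1, 1), i.e. the axis permutation x -> y -> z -> x (illustrative, assuming the classmethod signature shown):

    import torch

    half = torch.tensor(0.5)
    rot = Rot3Array.from_quaternion(half, half, half, half, normalize=True)
    # Expect xz = yx = zy = 1 and every other entry 0:
    # [[0, 0, 1],
    #  [1, 0, 0],
    #  [0, 1, 0]]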
def reshape(self, new_shape):
@@ -28,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
# With Param, a poor man's enum with attributes (Rust-style)
class ParamType(Enum):
LinearWeight = partial( # hack: partial prevents fns from becoming methods
lambda w: w.transpose(-1, -2)
lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)
)
LinearWeightMHA = partial(
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
@@ -58,6 +58,7 @@ class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other
stacked: bool = False
swap: bool = False
def process_translation_dict(d, top_layer=True):
@@ -101,6 +102,7 @@ def stacked(param_dict_list, out=None):
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
swap=v[0].swap
)
out[k] = stacked_param
@@ -122,7 +124,12 @@ def assign(translation_dict, orig_weights):
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
p.copy_(w)
if param.swap:
index = p.shape[0] // 2
p[:index].copy_(w[index:])
p[index:].copy_(w[:index])
else:
p.copy_(w)
except:
print(k)
print(ref[0].shape)
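A minimal standalone sketch of the half-swap above (illustrative 4x3 weight): the two halves of dim 0 trade places, which lines up pretrained multimer_v3 weights with the unified left/right projection split from the triangle-multiplication change earlier in this diff:

    import torch

    p = torch.empty(4, 3)
    w = torch.arange(12.).reshape(4, 3)
    index = p.shape[0] // 2
    p[:index].copy_(w[index:])  # rows 2, 3 of w -> rows 0, 1 of p
    p[index:].copy_(w[:index])  # rows 0, 1 of w -> rows 2, 3 of p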
@@ -145,12 +152,24 @@ def generate_translation_dict(model, version, is_multimer=False):
LinearBiasMultimer = lambda l: (
Param(l, param_type=ParamType.LinearBiasMultimer)
)
LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True))
LinearBiasSwap = lambda l: (Param(l, swap=True))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearParamsMHA = lambda l: {
"weights": LinearWeightMHA(l.weight),
"bias": LinearBiasMHA(l.bias),
}
LinearParamsSwap = lambda l: {
"weights": LinearWeightSwap(l.weight),
"bias": LinearBiasSwap(l.bias),
}
LinearParamsMultimer = lambda l: {
"weights": LinearWeightMultimer(l.weight),
"bias": LinearBiasMultimer(l.bias),
@@ -194,10 +213,11 @@
def TriMulOutParams(tri_mul, outgoing=True):
if re.fullmatch("^model_[1-5]_multimer_v3$", version):
lin_param_type = LinearParams if outgoing else LinearParamsSwap
d = {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"projection": LinearParams(tri_mul.linear_ab_p),
"gate": LinearParams(tri_mul.linear_ab_g),
"projection": lin_param_type(tri_mul.linear_ab_p),
"gate": lin_param_type(tri_mul.linear_ab_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
}
else:
@@ -276,24 +296,24 @@
}
PointProjectionParams = lambda pp: {
"point_projection": LinearParamsMultimer(
"point_projection": LinearParamsMHA(
pp.linear,
),
}
IPAParamsMultimer = lambda ipa: {
"q_scalar_projection": {
"weights": LinearWeightMultimer(
"weights": LinearWeightMHA(
ipa.linear_q.weight,
),
},
"k_scalar_projection": {
"weights": LinearWeightMultimer(
"weights": LinearWeightMHA(
ipa.linear_k.weight,
),
},
"v_scalar_projection": {
"weights": LinearWeightMultimer(
"weights": LinearWeightMHA(
ipa.linear_v.weight,
),
},
@@ -574,7 +594,7 @@
"template_pair_embedding_0": LinearParams(
temp_embedder.template_pair_embedder.dgram_linear
),
"template_pair_embedding_1": LinearParamsMultimer(
"template_pair_embedding_1": LinearParams(
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
),
"template_pair_embedding_2": LinearParams(
@@ -583,16 +603,16 @@
"template_pair_embedding_3": LinearParams(
temp_embedder.template_pair_embedder.aatype_linear_2
),
"template_pair_embedding_4": LinearParamsMultimer(
"template_pair_embedding_4": LinearParams(
temp_embedder.template_pair_embedder.x_linear
),
"template_pair_embedding_5": LinearParamsMultimer(
"template_pair_embedding_5": LinearParams(
temp_embedder.template_pair_embedder.y_linear
),
"template_pair_embedding_6": LinearParamsMultimer(
"template_pair_embedding_6": LinearParams(
temp_embedder.template_pair_embedder.z_linear
),
"template_pair_embedding_7": LinearParamsMultimer(
"template_pair_embedding_7": LinearParams(
temp_embedder.template_pair_embedder.backbone_mask_linear
),
"template_pair_embedding_8": LinearParams(
......@@ -600,7 +620,7 @@ def generate_translation_dict(model, version, is_multimer=False):
),
"template_embedding_iteration": tps_blocks_params,
"output_layer_norm": LayerNormParams(
model.template_embedder.template_pair_stack.layer_norm
temp_embedder.template_pair_stack.layer_norm
),
},
"output_linear": LinearParams(
@@ -1643,7 +1643,7 @@ def chain_center_of_mass_loss(
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True)
chains = asym_id.unique()
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
@@ -431,7 +431,7 @@ if __name__ == "__main__":
help="""Postfix for output prediction filenames"""
)
parser.add_argument(
"--data_random_seed", type=str, default=None
"--data_random_seed", type=int, default=None
)
parser.add_argument(
"--skip_relaxation", action="store_true", default=False,
@@ -45,7 +45,7 @@ def main(args):
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
uniref30_database_path=args.uniref30_database_path,
small_bfd_database_path=None,
template_featurizer=template_featurizer,
template_searcher=template_searcher,
@@ -15,7 +15,9 @@
import torch
import numpy as np
import unittest
from pathlib import Path
from tests.config import consts
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
@@ -23,15 +25,17 @@ from openfold.utils.import_weights import import_jax_weights_
class TestImportWeights(unittest.TestCase):
def test_import_jax_weights_(self):
npz_path = "openfold/resources/params/params_model_1_ptm.npz"
npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz"
c = model_config("model_1_ptm")
c = model_config(consts.model)
c.globals.blocks_per_ckpt = None
model = AlphaFold(c)
model.eval()
import_jax_weights_(
model,
npz_path,
version=consts.model
)
data = np.load(npz_path)
@@ -22,7 +22,7 @@ from openfold.data.data_modules import (
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)