Commit 9d4c9357 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Speed up template featurizer

parent e69b2a11
......@@ -211,8 +211,9 @@ config = mlc.ConfigDict(
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_template_hits": 20,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True,
"crop_size": 256,
"supervised": True,
......
......@@ -32,9 +32,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path: Optional[str] = None,
max_template_hits: int = 4,
template_release_dates_cache_path: Optional[str] = None,
use_small_bfd: bool = True,
output_raw: bool = False,
shuffle_top_k_prefiltered: Optional[int] = None,
mode: str = "train",
_output_raw: bool = False,
):
"""
Args:
......@@ -48,21 +48,38 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config:
A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
mapping_path:
A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in
the AlphaFold supplement
the AlphaFold supplement.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
mode:
"train", "val", or "predict"
"""
super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir
self.alignment_dir = alignment_dir
self.config = config
self.output_raw = output_raw
self.mode = mode
self._output_raw = _output_raw
valid_modes = ["train", "val", "predict"]
if(mode not in valid_modes):
......@@ -90,13 +107,14 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=None,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
)
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
)
if(not self.output_raw):
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
......@@ -153,7 +171,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir,
)
if(self.output_raw):
if(self._output_raw):
return data
feats = self.feature_pipeline.process_features(
......@@ -357,7 +375,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path=
self.template_release_dates_cache_path,
use_small_bfd=self.config.data_module.use_small_bfd,
)
if(self.training_mode):
......@@ -366,8 +383,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path,
max_template_hits=self.config.train.max_template_hits,
output_raw=True,
shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
mode="train",
_output_raw=True,
)
if(self.distillation_data_dir is not None):
......@@ -376,8 +395,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits,
output_raw=True,
mode="train",
_output_raw=True,
)
d_prob = self.config.train.distillation_prob
......
......@@ -123,14 +123,15 @@ def _is_after_cutoff(
Returns:
True if the template release date is after the cutoff, False otherwise.
"""
pdb_id_upper = pdb_id.upper()
if release_date_cutoff is None:
raise ValueError("The release_date_cutoff must not be None.")
if pdb_id in release_dates:
return release_dates[pdb_id] > release_date_cutoff
if pdb_id_upper in release_dates:
return release_dates[pdb_id_upper] > release_date_cutoff
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
logging.info(
logging.warning(
"Template structure not in release dates dict: %s", pdb_id
)
return False
......@@ -183,7 +184,7 @@ def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
data = json.load(fp)
return {
pdb: to_date(v)
pdb.upper(): to_date(v)
for pdb, d in data.items()
for k, v in d.items()
if k == "release_date"
......@@ -239,8 +240,9 @@ def _assess_hhsearch_hit(
)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
date = release_dates[hit_pdb_code.upper()]
raise DateError(
f"Date ({release_dates[hit_pdb_code]}) > max template date "
f"Date ({date}) > max template date "
f"({release_date_cutoff})."
)
......@@ -735,6 +737,12 @@ def _build_query_to_hit_index_mapping(
return mapping
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
    """Outcome of prefiltering one template hit before full featurization.

    Returned by the prefilter step so the caller can decide whether the
    hit is worth the (expensive) feature-extraction pass.
    """

    # True if the hit passed all prefilter checks and may be featurized.
    valid: bool
    # Error message when the rejection should be treated as fatal
    # (set in strict-error-check mode); None otherwise.
    error: Optional[str]
    # Non-fatal diagnostic message, if any. NOTE(review): never populated
    # in the visible call sites — confirm it is still needed.
    warning: Optional[str]
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
features: Optional[Mapping[str, Any]]
......@@ -742,18 +750,15 @@ class SingleHitResult:
warning: Optional[str]
def _process_single_hit(
def _prefilter_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
):
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
......@@ -761,7 +766,8 @@ def _process_single_hit(
if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code]
# Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
# Pass hit_pdb_code since it might have changed due to the pdb being
# obsolete.
try:
_assess_hhsearch_hit(
hit=hit,
......@@ -772,15 +778,32 @@ def _process_single_hit(
release_date_cutoff=max_template_date,
)
except PrefilterError as e:
msg = f"hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}"
hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)
):
# In strict mode we treat some prefilter cases as errors.
return SingleHitResult(features=None, error=msg, warning=None)
return PrefilterResult(valid=False, error=msg, warning=None)
return SingleHitResult(features=None, error=None, warning=None)
return PrefilterResult(valid=False, error=None, warning=None)
return PrefilterResult(valid=True, error=None, warning=None)
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
kalign_binary_path: str,
strict_error_check: bool = False,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
mapping = _build_query_to_hit_index_mapping(
hit.query,
......@@ -901,6 +924,7 @@ class TemplateHitFeaturizer:
release_dates_path: Optional[str],
obsolete_pdbs_path: Optional[str],
strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None,
):
"""Initializes the Template Search.
......@@ -938,7 +962,7 @@ class TemplateHitFeaturizer:
raise ValueError(
"max_template_date must be set and have format YYYY-MM-DD."
)
self._max_hits = max_hits
self.max_hits = max_hits
self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check
......@@ -958,6 +982,8 @@ class TemplateHitFeaturizer:
else:
self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
def get_templates(
self,
query_sequence: str,
......@@ -986,19 +1012,48 @@ class TemplateHitFeaturizer:
errors = []
warnings = []
for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True):
filtered = []
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
)
if prefilter_result.error:
errors.append(prefilter_result.error)
if prefilter_result.warning:
warnings.append(prefilter_result.warning)
if prefilter_result.valid:
filtered.append(hit)
filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
)
idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
idx[:stk] = np.random.permutation(idx[:stk])
for i in idx:
# We got all the templates we wanted, stop processing hits.
if num_hits >= self._max_hits:
if num_hits >= self.max_hits:
break
hit = filtered[i]
result = _process_single_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path,
)
......
......@@ -259,7 +259,7 @@ class TemplatePairStack(nn.Module):
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for i in range(no_blocks):
for _ in range(no_blocks):
block = TemplatePairStackBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
......
......@@ -90,6 +90,7 @@ def compute_fape(
local_target_pos = target_frames.invert()[..., None].apply(
target_positions[..., None, :, :],
)
error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
)
......@@ -161,7 +162,9 @@ def backbone_loss(
1 - use_clamped_fape
)
# Average over the batch dimension
fape_loss = torch.mean(fape_loss)
return fape_loss
......@@ -231,7 +234,12 @@ def fape_loss(
**{**batch, **config.sidechain},
)
return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def supervised_chi_loss(
......@@ -290,6 +298,9 @@ def supervised_chi_loss(
loss = loss + angle_norm_weight * angle_norm_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
......@@ -388,6 +399,9 @@ def lddt_loss(
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the batch dimension
loss = torch.mean(loss)
return loss
......@@ -433,6 +447,9 @@ def distogram_loss(
mean = mean / denom[..., None]
mean = torch.sum(mean, dim=-1)
# Average over the batch dimensions
mean = torch.mean(mean)
return mean
......@@ -580,6 +597,9 @@ def tm_loss(
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the loss dimension
loss = torch.mean(loss)
return loss
......@@ -1351,6 +1371,8 @@ def experimentally_resolved_loss(
(resolution >= min_resolution) & (resolution <= max_resolution)
)
loss = torch.mean(loss)
return loss
......@@ -1469,8 +1491,8 @@ class AlphaFoldLoss(nn.Module):
}
cum_loss = 0
for k, loss_fn in loss_fns.items():
weight = self.config[k].weight
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
if weight:
loss = loss_fn()
cum_loss = cum_loss + weight * loss
......
......@@ -50,12 +50,10 @@ def main(args):
model = model.to(args.model_device)
# FEATURE COLLECTION AND PROCESSING
num_ensemble = 1
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=args.max_template_hits,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=args.obsolete_pdbs_path
......@@ -85,7 +83,6 @@ def main(args):
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
config.data.predict.num_ensemble = num_ensemble
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
......
......@@ -40,9 +40,6 @@ def add_data_args(parser: argparse.ArgumentParser):
'--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument(
'--max_template_hits', type=int, default=20,
)
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment