Commit 9d4c9357 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Speed up template featurizer

parent e69b2a11
...@@ -211,8 +211,9 @@ config = mlc.ConfigDict( ...@@ -211,8 +211,9 @@ config = mlc.ConfigDict(
"subsample_templates": True, "subsample_templates": True,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_template_hits": 20, "max_template_hits": 4,
"max_templates": 4, "max_templates": 4,
"shuffle_top_k_prefiltered": 20,
"crop": True, "crop": True,
"crop_size": 256, "crop_size": 256,
"supervised": True, "supervised": True,
......
...@@ -32,9 +32,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -32,9 +32,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mapping_path: Optional[str] = None, mapping_path: Optional[str] = None,
max_template_hits: int = 4, max_template_hits: int = 4,
template_release_dates_cache_path: Optional[str] = None, template_release_dates_cache_path: Optional[str] = None,
use_small_bfd: bool = True, shuffle_top_k_prefiltered: Optional[int] = None,
output_raw: bool = False,
mode: str = "train", mode: str = "train",
_output_raw: bool = False,
): ):
""" """
Args: Args:
...@@ -48,21 +48,38 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -48,21 +48,38 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID} I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files. files.
template_mmcif_dir:
Path to a directory containing template mmCIF files.
config: config:
A dataset config object. See openfold.config A dataset config object. See openfold.config
kalign_binary_path:
Path to kalign binary.
mapping_path: mapping_path:
A json file containing a mapping from consecutive numerical A json file containing a mapping from consecutive numerical
ids to sample names (matching the directories in data_dir). ids to sample names (matching the directories in data_dir).
Samples not in this mapping are ignored. Can be used to Samples not in this mapping are ignored. Can be used to
implement the various training-time filters described in implement the various training-time filters described in
the AlphaFold supplement the AlphaFold supplement.
max_template_hits:
An upper bound on how many templates are considered. During
training, the templates ultimately used are subsampled
from this total quantity.
template_release_dates_cache_path:
Path to the output of scripts/generate_mmcif_cache.
shuffle_top_k_prefiltered:
Whether to uniformly shuffle the top k template hits before
parsing max_template_hits of them. Can be used to
approximate DeepMind's training-time template subsampling
scheme much more performantly.
mode:
"train", "val", or "predict"
""" """
super(OpenFoldSingleDataset, self).__init__() super(OpenFoldSingleDataset, self).__init__()
self.data_dir = data_dir self.data_dir = data_dir
self.alignment_dir = alignment_dir self.alignment_dir = alignment_dir
self.config = config self.config = config
self.output_raw = output_raw
self.mode = mode self.mode = mode
self._output_raw = _output_raw
valid_modes = ["train", "val", "predict"] valid_modes = ["train", "val", "predict"]
if(mode not in valid_modes): if(mode not in valid_modes):
...@@ -90,13 +107,14 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -90,13 +107,14 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
kalign_binary_path=kalign_binary_path, kalign_binary_path=kalign_binary_path,
release_dates_path=template_release_dates_cache_path, release_dates_path=template_release_dates_cache_path,
obsolete_pdbs_path=None, obsolete_pdbs_path=None,
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
) )
self.data_pipeline = data_pipeline.DataPipeline( self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
) )
if(not self.output_raw): if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir): def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
...@@ -153,7 +171,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -153,7 +171,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir, alignment_dir=alignment_dir,
) )
if(self.output_raw): if(self._output_raw):
return data return data
feats = self.feature_pipeline.process_features( feats = self.feature_pipeline.process_features(
...@@ -357,7 +375,6 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -357,7 +375,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
kalign_binary_path=self.kalign_binary_path, kalign_binary_path=self.kalign_binary_path,
template_release_dates_cache_path= template_release_dates_cache_path=
self.template_release_dates_cache_path, self.template_release_dates_cache_path,
use_small_bfd=self.config.data_module.use_small_bfd,
) )
if(self.training_mode): if(self.training_mode):
...@@ -366,8 +383,10 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -366,8 +383,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir=self.train_alignment_dir, alignment_dir=self.train_alignment_dir,
mapping_path=self.train_mapping_path, mapping_path=self.train_mapping_path,
max_template_hits=self.config.train.max_template_hits, max_template_hits=self.config.train.max_template_hits,
output_raw=True, shuffle_top_k_prefiltered=
self.config.train.shuffle_top_k_prefiltered,
mode="train", mode="train",
_output_raw=True,
) )
if(self.distillation_data_dir is not None): if(self.distillation_data_dir is not None):
...@@ -376,8 +395,8 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -376,8 +395,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
alignment_dir=self.distillation_alignment_dir, alignment_dir=self.distillation_alignment_dir,
mapping_path=self.distillation_mapping_path, mapping_path=self.distillation_mapping_path,
max_template_hits=self.train.max_template_hits, max_template_hits=self.train.max_template_hits,
output_raw=True,
mode="train", mode="train",
_output_raw=True,
) )
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
......
...@@ -123,14 +123,15 @@ def _is_after_cutoff( ...@@ -123,14 +123,15 @@ def _is_after_cutoff(
Returns: Returns:
True if the template release date is after the cutoff, False otherwise. True if the template release date is after the cutoff, False otherwise.
""" """
pdb_id_upper = pdb_id.upper()
if release_date_cutoff is None: if release_date_cutoff is None:
raise ValueError("The release_date_cutoff must not be None.") raise ValueError("The release_date_cutoff must not be None.")
if pdb_id in release_dates: if pdb_id_upper in release_dates:
return release_dates[pdb_id] > release_date_cutoff return release_dates[pdb_id_upper] > release_date_cutoff
else: else:
# Since this is just a quick prefilter to reduce the number of mmCIF files # Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here. # we need to parse, we don't have to worry about returning True here.
logging.info( logging.warning(
"Template structure not in release dates dict: %s", pdb_id "Template structure not in release dates dict: %s", pdb_id
) )
return False return False
...@@ -183,7 +184,7 @@ def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: ...@@ -183,7 +184,7 @@ def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
data = json.load(fp) data = json.load(fp)
return { return {
pdb: to_date(v) pdb.upper(): to_date(v)
for pdb, d in data.items() for pdb, d in data.items()
for k, v in d.items() for k, v in d.items()
if k == "release_date" if k == "release_date"
...@@ -239,8 +240,9 @@ def _assess_hhsearch_hit( ...@@ -239,8 +240,9 @@ def _assess_hhsearch_hit(
) )
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
date = release_dates[hit_pdb_code.upper()]
raise DateError( raise DateError(
f"Date ({release_dates[hit_pdb_code]}) > max template date " f"Date ({date}) > max template date "
f"({release_date_cutoff})." f"({release_date_cutoff})."
) )
...@@ -735,6 +737,12 @@ def _build_query_to_hit_index_mapping( ...@@ -735,6 +737,12 @@ def _build_query_to_hit_index_mapping(
return mapping return mapping
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
valid: bool
error: Optional[str]
warning: Optional[str]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class SingleHitResult: class SingleHitResult:
features: Optional[Mapping[str, Any]] features: Optional[Mapping[str, Any]]
...@@ -742,18 +750,15 @@ class SingleHitResult: ...@@ -742,18 +750,15 @@ class SingleHitResult:
warning: Optional[str] warning: Optional[str]
def _process_single_hit( def _prefilter_hit(
query_sequence: str, query_sequence: str,
query_pdb_code: Optional[str], query_pdb_code: Optional[str],
hit: parsers.TemplateHit, hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime, max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime], release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str], obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False, strict_error_check: bool = False,
) -> SingleHitResult: ):
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit. # Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
...@@ -761,7 +766,8 @@ def _process_single_hit( ...@@ -761,7 +766,8 @@ def _process_single_hit(
if hit_pdb_code in obsolete_pdbs: if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code] hit_pdb_code = obsolete_pdbs[hit_pdb_code]
# Pass hit_pdb_code since it might have changed due to the pdb being obsolete. # Pass hit_pdb_code since it might have changed due to the pdb being
# obsolete.
try: try:
_assess_hhsearch_hit( _assess_hhsearch_hit(
hit=hit, hit=hit,
...@@ -772,15 +778,32 @@ def _process_single_hit( ...@@ -772,15 +778,32 @@ def _process_single_hit(
release_date_cutoff=max_template_date, release_date_cutoff=max_template_date,
) )
except PrefilterError as e: except PrefilterError as e:
msg = f"hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}" hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
logging.info("%s: %s", query_pdb_code, msg) logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance( if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError) e, (DateError, PdbIdError, DuplicateError)
): ):
# In strict mode we treat some prefilter cases as errors. # In strict mode we treat some prefilter cases as errors.
return SingleHitResult(features=None, error=msg, warning=None) return PrefilterResult(valid=False, error=msg, warning=None)
return SingleHitResult(features=None, error=None, warning=None) return PrefilterResult(valid=False, error=None, warning=None)
return PrefilterResult(valid=True, error=None, warning=None)
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
kalign_binary_path: str,
strict_error_check: bool = False,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
mapping = _build_query_to_hit_index_mapping( mapping = _build_query_to_hit_index_mapping(
hit.query, hit.query,
...@@ -901,6 +924,7 @@ class TemplateHitFeaturizer: ...@@ -901,6 +924,7 @@ class TemplateHitFeaturizer:
release_dates_path: Optional[str], release_dates_path: Optional[str],
obsolete_pdbs_path: Optional[str], obsolete_pdbs_path: Optional[str],
strict_error_check: bool = False, strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None,
): ):
"""Initializes the Template Search. """Initializes the Template Search.
...@@ -938,7 +962,7 @@ class TemplateHitFeaturizer: ...@@ -938,7 +962,7 @@ class TemplateHitFeaturizer:
raise ValueError( raise ValueError(
"max_template_date must be set and have format YYYY-MM-DD." "max_template_date must be set and have format YYYY-MM-DD."
) )
self._max_hits = max_hits self.max_hits = max_hits
self._kalign_binary_path = kalign_binary_path self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check self._strict_error_check = strict_error_check
...@@ -958,6 +982,8 @@ class TemplateHitFeaturizer: ...@@ -958,6 +982,8 @@ class TemplateHitFeaturizer:
else: else:
self._obsolete_pdbs = {} self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
def get_templates( def get_templates(
self, self,
query_sequence: str, query_sequence: str,
...@@ -986,19 +1012,48 @@ class TemplateHitFeaturizer: ...@@ -986,19 +1012,48 @@ class TemplateHitFeaturizer:
errors = [] errors = []
warnings = [] warnings = []
for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True): filtered = []
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
)
if prefilter_result.error:
errors.append(prefilter_result.error)
if prefilter_result.warning:
warnings.append(prefilter_result.warning)
if prefilter_result.valid:
filtered.append(hit)
filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
)
idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
idx[:stk] = np.random.permutation(idx[:stk])
for i in idx:
# We got all the templates we wanted, stop processing hits. # We got all the templates we wanted, stop processing hits.
if num_hits >= self._max_hits: if num_hits >= self.max_hits:
break break
hit = filtered[i]
result = _process_single_hit( result = _process_single_hit(
query_sequence=query_sequence, query_sequence=query_sequence,
query_pdb_code=query_pdb_code, query_pdb_code=query_pdb_code,
hit=hit, hit=hit,
mmcif_dir=self._mmcif_dir, mmcif_dir=self._mmcif_dir,
max_template_date=template_cutoff_date, max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check, strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path, kalign_binary_path=self._kalign_binary_path,
) )
......
...@@ -259,7 +259,7 @@ class TemplatePairStack(nn.Module): ...@@ -259,7 +259,7 @@ class TemplatePairStack(nn.Module):
self.blocks_per_ckpt = blocks_per_ckpt self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
for i in range(no_blocks): for _ in range(no_blocks):
block = TemplatePairStackBlock( block = TemplatePairStackBlock(
c_t=c_t, c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att, c_hidden_tri_att=c_hidden_tri_att,
......
...@@ -90,6 +90,7 @@ def compute_fape( ...@@ -90,6 +90,7 @@ def compute_fape(
local_target_pos = target_frames.invert()[..., None].apply( local_target_pos = target_frames.invert()[..., None].apply(
target_positions[..., None, :, :], target_positions[..., None, :, :],
) )
error_dist = torch.sqrt( error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
) )
...@@ -161,7 +162,9 @@ def backbone_loss( ...@@ -161,7 +162,9 @@ def backbone_loss(
1 - use_clamped_fape 1 - use_clamped_fape
) )
# Average over the batch dimension
fape_loss = torch.mean(fape_loss) fape_loss = torch.mean(fape_loss)
return fape_loss return fape_loss
...@@ -231,7 +234,12 @@ def fape_loss( ...@@ -231,7 +234,12 @@ def fape_loss(
**{**batch, **config.sidechain}, **{**batch, **config.sidechain},
) )
return config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def supervised_chi_loss( def supervised_chi_loss(
...@@ -290,6 +298,9 @@ def supervised_chi_loss( ...@@ -290,6 +298,9 @@ def supervised_chi_loss(
loss = loss + angle_norm_weight * angle_norm_loss loss = loss + angle_norm_weight * angle_norm_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss return loss
...@@ -388,6 +399,9 @@ def lddt_loss( ...@@ -388,6 +399,9 @@ def lddt_loss(
(resolution >= min_resolution) & (resolution <= max_resolution) (resolution >= min_resolution) & (resolution <= max_resolution)
) )
# Average over the batch dimension
loss = torch.mean(loss)
return loss return loss
...@@ -433,6 +447,9 @@ def distogram_loss( ...@@ -433,6 +447,9 @@ def distogram_loss(
mean = mean / denom[..., None] mean = mean / denom[..., None]
mean = torch.sum(mean, dim=-1) mean = torch.sum(mean, dim=-1)
# Average over the batch dimensions
mean = torch.mean(mean)
return mean return mean
...@@ -580,6 +597,9 @@ def tm_loss( ...@@ -580,6 +597,9 @@ def tm_loss(
(resolution >= min_resolution) & (resolution <= max_resolution) (resolution >= min_resolution) & (resolution <= max_resolution)
) )
# Average over the loss dimension
loss = torch.mean(loss)
return loss return loss
...@@ -1351,6 +1371,8 @@ def experimentally_resolved_loss( ...@@ -1351,6 +1371,8 @@ def experimentally_resolved_loss(
(resolution >= min_resolution) & (resolution <= max_resolution) (resolution >= min_resolution) & (resolution <= max_resolution)
) )
loss = torch.mean(loss)
return loss return loss
...@@ -1469,8 +1491,8 @@ class AlphaFoldLoss(nn.Module): ...@@ -1469,8 +1491,8 @@ class AlphaFoldLoss(nn.Module):
} }
cum_loss = 0 cum_loss = 0
for k, loss_fn in loss_fns.items(): for loss_name, loss_fn in loss_fns.items():
weight = self.config[k].weight weight = self.config[loss_name].weight
if weight: if weight:
loss = loss_fn() loss = loss_fn()
cum_loss = cum_loss + weight * loss cum_loss = cum_loss + weight * loss
......
...@@ -50,12 +50,10 @@ def main(args): ...@@ -50,12 +50,10 @@ def main(args):
model = model.to(args.model_device) model = model.to(args.model_device)
# FEATURE COLLECTION AND PROCESSING # FEATURE COLLECTION AND PROCESSING
num_ensemble = 1
template_featurizer = templates.TemplateHitFeaturizer( template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir, mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date, max_template_date=args.max_template_date,
max_hits=args.max_template_hits, max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path, kalign_binary_path=args.kalign_binary_path,
release_dates_path=None, release_dates_path=None,
obsolete_pdbs_path=args.obsolete_pdbs_path obsolete_pdbs_path=args.obsolete_pdbs_path
...@@ -85,7 +83,6 @@ def main(args): ...@@ -85,7 +83,6 @@ def main(args):
random_seed = args.data_random_seed random_seed = args.data_random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(sys.maxsize) random_seed = random.randrange(sys.maxsize)
config.data.predict.num_ensemble = num_ensemble
feature_processor = feature_pipeline.FeaturePipeline(config.data) feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base): if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base) os.makedirs(output_dir_base)
......
...@@ -40,9 +40,6 @@ def add_data_args(parser: argparse.ArgumentParser): ...@@ -40,9 +40,6 @@ def add_data_args(parser: argparse.ArgumentParser):
'--max_template_date', type=str, '--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"), default=date.today().strftime("%Y-%m-%d"),
) )
parser.add_argument(
'--max_template_hits', type=int, default=20,
)
parser.add_argument( parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None '--obsolete_pdbs_path', type=str, default=None
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment