Unverified commit 2dc080ce, authored by Christina Floristean and committed by GitHub
Browse files

Merge pull request #374 from aqlaboratory/msa-block-deletion

Fix for MSA block deletion
parents a64f1a29 b935639b
......@@ -156,10 +156,12 @@ def model_config(
elif name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
......@@ -311,6 +313,11 @@ config = mlc.ConfigDict(
"true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [],
},
"block_delete_msa": {
"msa_fraction_per_block": 0.3,
"randomize_num_blocks": False,
"num_blocks": 5,
},
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
......@@ -355,6 +362,7 @@ config = mlc.ConfigDict(
"predict": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_extra_msa": 1024,
......@@ -368,6 +376,7 @@ config = mlc.ConfigDict(
"eval": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
......@@ -381,6 +390,7 @@ config = mlc.ConfigDict(
"train": {
"fixed_size": True,
"subsample_templates": True,
"block_delete_msa": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
......
......@@ -253,28 +253,33 @@ def block_delete_msa(protein, config):
* config.msa_fraction_per_block
).to(torch.int32)
if int(block_num_seq) == 0:
return protein
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(
0, config.num_blocks + 1
).sample()
nb = int(torch.randint(
low=0,
high=config.num_blocks + 1,
size=(1,),
device=protein["msa"].device,
)[0])
else:
nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq)
del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
del_indices = torch.unique(torch.reshape(del_blocks, [-1]))
# Make sure we keep the original sequence
combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]
keep_indices = torch.squeeze(difference, 0)
keep_indices = uniques[counts == 1]
assert int(keep_indices[0]) == 0
for k in MSA_FEATURE_NAMES:
if k in protein:
protein[k] = torch.gather(protein[k], keep_indices)
protein[k] = torch.index_select(protein[k], 0, keep_indices)
return protein
......
......@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
if mode_cfg.block_delete_msa:
transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa))
if "max_distillation_msa_clusters" in mode_cfg:
transforms.append(
data_transforms.sample_msa_distillation(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment