Commit b935639b authored by Christina Floristean's avatar Christina Floristean
Browse files

Fix bugs in block deletion, disable for soloseq

parent 616de3df
......@@ -156,10 +156,12 @@ def model_config(
elif name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.block_delete_msa = False
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
......
......@@ -253,28 +253,33 @@ def block_delete_msa(protein, config):
* config.msa_fraction_per_block
).to(torch.int32)
if int(block_num_seq) == 0:
return protein
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(
0, config.num_blocks + 1
).sample()
nb = int(torch.randint(
low=0,
high=config.num_blocks + 1,
size=(1,),
device=protein["msa"].device,
)[0])
else:
nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq)
del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
del_indices = torch.unique(torch.reshape(del_blocks, [-1]))
# Make sure we keep the original sequence
combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]
keep_indices = torch.squeeze(difference, 0)
keep_indices = uniques[counts == 1]
assert int(keep_indices[0]) == 0
for k in MSA_FEATURE_NAMES:
if k in protein:
protein[k] = torch.gather(protein[k], keep_indices)
protein[k] = torch.index_select(protein[k], 0, keep_indices)
return protein
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment