Commit 79f9f03d authored by Gustaf Ahdritz

Fix distillation bug

parent b32bfeec
@@ -208,6 +208,10 @@ and supports the full range of training options that entails, including
 multi-node distributed training. For more information, consult PyTorch
 Lightning documentation and the `--help` flag of the training script.
+Note that the data directory can also contain PDB files previously output by
+the model. These are treated as members of the self-distillation set and are
+subjected to distillation-set-only preprocessing steps.
 ## Testing
 To run unit tests, use
......
@@ -176,7 +176,7 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
 @curry1
 def sample_msa(protein, max_seq, keep_extra, seed=None):
     """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
     num_seq = protein["msa"].shape[0]
     g = torch.Generator(device=protein["msa"].device)
     if seed is not None:
@@ -202,7 +202,7 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
 @curry1
 def sample_msa_distillation(protein, max_seq):
     if(protein["is_distillation"] == 1):
-        protein = sample_msa(protein, max_seq, keep_extra=False)
+        protein = sample_msa(max_seq, keep_extra=False)(protein)
     return protein
......
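The one-line change in `sample_msa_distillation` makes sense if `@curry1` turns a function into one that takes every argument except the first and returns a callable awaiting that first argument. The definition of `curry1` is not shown in this diff, so the version below is an assumed stand-in, and the body of `sample_msa` is a toy replacement that slices instead of sampling. It is a minimal sketch of why the old call form could not have returned a protein dictionary:

```python
from functools import wraps


def curry1(f):
    """Assumed behaviour (not shown in the diff): bind all arguments but the first."""
    @wraps(f)
    def fc(*args, **kwargs):
        # The returned callable receives the first positional argument later.
        return lambda x: f(x, *args, **kwargs)
    return fc


@curry1
def sample_msa(protein, max_seq, keep_extra):
    # Toy body for illustration only: keeps the first max_seq rows
    # rather than sampling randomly as the real function does.
    out = dict(protein)
    out["msa"] = protein["msa"][:max_seq]
    if keep_extra:
        out["extra_msa"] = protein["msa"][max_seq:]
    return out


protein = {"msa": [[0, 1], [2, 3], [4, 5]], "is_distillation": 1}

# Old call form: every argument is captured as "the rest", so the result
# is a callable still waiting for its first argument, not a protein dict.
broken = sample_msa(protein, 2, keep_extra=False)

# New call form: configure first, then apply to the protein.
fixed = sample_msa(2, keep_extra=False)(protein)
```

Under that assumption, the old form silently replaced `protein` with a lambda, which would only fail later when a downstream transform tried to index it.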