Commit 0026173e authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Cleaned up `precompute_embeddings.py`.

parent bcc6d97b
...@@ -95,8 +95,7 @@ def main(args): ...@@ -95,8 +95,7 @@ def main(args):
dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
) )
logging.info("Loaded all sequences") logging.info("Loaded all sequences")
assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers) repr_layers = [33]
repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]
with torch.no_grad(): with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader): for batch_idx, (labels, strs, toks) in enumerate(data_loader):
...@@ -118,11 +117,10 @@ def main(args): ...@@ -118,11 +117,10 @@ def main(args):
os.makedirs(os.path.join(args.output_dir, label), exist_ok=True) os.makedirs(os.path.join(args.output_dir, label), exist_ok=True)
result = {"label": label} result = {"label": label}
if "per_tok" in args.include: result["representations"] = {
result["representations"] = { layer: t[i, 1: len(strs[i]) + 1].clone()
layer: t[i, 1: len(strs[i]) + 1].clone() for layer, t in representations.items()
for layer, t in representations.items() }
}
torch.save( torch.save(
result, result,
os.path.join(args.output_dir, label, label+".pt") os.path.join(args.output_dir, label, label+".pt")
...@@ -146,15 +144,6 @@ if __name__ == "__main__": ...@@ -146,15 +144,6 @@ if __name__ == "__main__":
"--toks_per_batch", type=int, default=4096, "--toks_per_batch", type=int, default=4096,
help="maximum tokens in a batch" help="maximum tokens in a batch"
) )
parser.add_argument(
"--repr_layers", type=int, default=[-1], nargs="+",
help="Layer indices from which to extract representations (0 to num_layers, inclusive)"
)
parser.add_argument(
"--include", type=str, default=["per_tok"], nargs="+",
choices=["mean", "per_tok", "bos", "contacts"],
help="Specify which representations to return"
)
parser.add_argument( parser.add_argument(
"--truncate", action="store_true", default=True, "--truncate", action="store_true", default=True,
help="Truncate sequences longer than 1022 (ESM restriction). Default: True" help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment