Commit 0026173e authored by Sachin Kadyan
Browse files

Cleaned up `precompute_embeddings.py`.

parent bcc6d97b
......@@ -95,8 +95,7 @@ def main(args):
dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
)
logging.info("Loaded all sequences")
assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]
repr_layers = [33]
with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
......@@ -118,7 +117,6 @@ def main(args):
os.makedirs(os.path.join(args.output_dir, label), exist_ok=True)
result = {"label": label}
if "per_tok" in args.include:
result["representations"] = {
layer: t[i, 1: len(strs[i]) + 1].clone()
for layer, t in representations.items()
......@@ -146,15 +144,6 @@ if __name__ == "__main__":
"--toks_per_batch", type=int, default=4096,
help="maximum tokens in a batch"
)
parser.add_argument(
"--repr_layers", type=int, default=[-1], nargs="+",
help="Layer indices from which to extract representations (0 to num_layers, inclusive)"
)
parser.add_argument(
"--include", type=str, default=["per_tok"], nargs="+",
choices=["mean", "per_tok", "bos", "contacts"],
help="Specify which representations to return"
)
parser.add_argument(
"--truncate", action="store_true", default=True,
help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment