Commit 92835fd5 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

More cleaning of bulk embedding generation script

parent 0026173e
......@@ -110,7 +110,7 @@ def main(args):
logits = out["logits"].to(device="cpu")
representations = {
layer: t.to(device="cpu") for layer, t in out["representations"].items()
33: out["representations"][33].to(device="cpu")
}
for i, label in enumerate(labels):
......@@ -118,8 +118,7 @@ def main(args):
result = {"label": label}
result["representations"] = {
layer: t[i, 1: len(strs[i]) + 1].clone()
for layer, t in representations.items()
33: representations[33][i, 1: len(strs[i]) + 1].clone()
}
torch.save(
result,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment