"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "02ef825be208badb6bb7bf0641e7035406690b18"
Unverified Commit 3aca02ef authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Bart example: model.to(device) (#3194)

parent 5164ea91
...@@ -18,7 +18,7 @@ def chunks(lst, n): ...@@ -18,7 +18,7 @@ def chunks(lst, n):
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
fout = Path(out_file).open("w") fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,) model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large") tokenizer = BartTokenizer.from_pretrained("bart-large")
for batch in tqdm(list(chunks(lns, batch_size))): for batch in tqdm(list(chunks(lns, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment