Unverified Commit 1c3ab3e5 authored by Alberto Villa's avatar Alberto Villa Committed by GitHub
Browse files

Typo in usage example, changed to device instead of torch_device (#11979)

parent 47a98fc4
...@@ -90,7 +90,7 @@ Usage Example ...@@ -90,7 +90,7 @@ Usage Example
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu' >>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name) >>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device) >>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
>>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device) >>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(device)
>>> translated = model.generate(**batch) >>> translated = model.generate(**batch)
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True) >>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
>>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers." >>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment