"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cdcc01be0ead8e3473ff88b95b8c53755a60750f"
Commit 1cb267ed authored by Aditya Chetan, committed by Facebook GitHub Bot

Fixing example of batched predictions for Roberta (#1195)

Summary:
For batched predictions with RoBERTa, the README gave an example that was pretty unclear. After a thorough discussion with ngoyal2707 in issue https://github.com/pytorch/fairseq/issues/1167, he gave a clear example of how batched predictions are supposed to be done. Since I spent a lot of time on this inconsistency, I thought it might benefit the community if his solution were in the official README 😄 !

For more details, see issue https://github.com/pytorch/fairseq/issues/1167
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1195

Differential Revision: D17639354

Pulled By: myleott

fbshipit-source-id: 3eb60c5804a6481f533b19073da7880dfd0d522d
parent 86857a58
@@ -146,11 +146,26 @@ logprobs = roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.
 ##### Batched prediction:
 ```python
+import torch
 from fairseq.data.data_utils import collate_tokens
-sentences = ['Hello world.', 'Another unrelated sentence.']
-batch = collate_tokens([roberta.encode(sent) for sent in sentences], pad_idx=1)
-logprobs = roberta.predict('new_task', batch)
-assert logprobs.size() == torch.Size([2, 3])
+
+roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
+roberta.eval()
+
+batch_of_pairs = [
+    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
+    ['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
+    ['potatoes are awesome.', 'I like to run.'],
+    ['Mars is very far from earth.', 'Mars is very close.'],
+]
+
+batch = collate_tokens(
+    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
+)
+
+logprobs = roberta.predict('mnli', batch)
+print(logprobs.argmax(dim=1))
+# tensor([0, 2, 1, 0])
 ```
 ##### Using the GPU:
...
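For reference, here is a minimal sketch of how the new README snippet might be consumed end to end. It is illustrative only and not part of the commit: the label order (0: contradiction, 1: neutral, 2: entailment) is assumed from the single-pair MNLI example elsewhere in the fairseq README, and pad_idx=1 is assumed to match RoBERTa's padding index in fairseq's dictionary.

```python
# Illustrative sketch (not part of this commit): run batched MNLI
# predictions and map the argmax indices back to label names.
import torch
from fairseq.data.data_utils import collate_tokens

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()

batch_of_pairs = [
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
    ['Mars is very far from earth.', 'Mars is very close.'],
]

# pad_idx=1 is assumed to be RoBERTa's padding index in the fairseq dictionary.
batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

# Assumed MNLI label order, following the single-pair example in the README.
labels = ['contradiction', 'neutral', 'entailment']
with torch.no_grad():
    logprobs = roberta.predict('mnli', batch)
for pair, pred in zip(batch_of_pairs, logprobs.argmax(dim=1)):
    print(pair, '->', labels[pred.item()])
```

For these two pairs, the expected output shown in the diff (tensor([0, 2, 1, 0]), first and last entries) corresponds to contradiction in both cases.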