"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "5218f45c2c3b254d7a82515b24cbac6902296b8f"
Commit 1cb267ed authored by Aditya Chetan's avatar Aditya Chetan Committed by Facebook Github Bot
Browse files

Fixing example of batched predictions for Roberta (#1195)

Summary:
For batched predictions in Roberta, the README was giving an example that was pretty unclear. After a thorough discussion with ngoyal2707 in issue https://github.com/pytorch/fairseq/issues/1167 he gave a clear example of how batched predictions were supposed to be done. Since I spent a lot of time on this inconsistency, I thought that it might benefit the community if his solution was in the official README 😄 !

For for details, see issue https://github.com/pytorch/fairseq/issues/1167
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1195

Differential Revision: D17639354

Pulled By: myleott

fbshipit-source-id: 3eb60c5804a6481f533b19073da7880dfd0d522d
parent 86857a58
......@@ -146,11 +146,26 @@ logprobs = roberta.predict('new_task', tokens) # tensor([[-1.1050, -1.0672, -1.
##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens
sentences = ['Hello world.', 'Another unrelated sentence.']
batch = collate_tokens([roberta.encode(sent) for sent in sentences], pad_idx=1)
logprobs = roberta.predict('new_task', batch)
assert logprobs.size() == torch.Size([2, 3])
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()
batch_of_pairs = [
['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
['potatoes are awesome.', 'I like to run.'],
['Mars is very far from earth.', 'Mars is very close.'],
]
batch = collate_tokens(
[roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)
logprobs = roberta.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2, 1, 0])
```
##### Using the GPU:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment