Commit 3e3fe722 authored by Roman Rädle's avatar Roman Rädle Committed by Facebook Github Bot
Browse files

Return predicted token for RoBERTa filling mask

Summary:
Added the `predicted_token` to each `topk` filled output item

Updated RoBERTa filling mask example in README.md

Reviewed By: myleott

Differential Revision: D17188810

fbshipit-source-id: 5fdc57ff2c13239dabf13a8dad43ae9a55e8931c
parent 1566cfb9
...@@ -167,13 +167,13 @@ RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the ...@@ -167,13 +167,13 @@ RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/): [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python ```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3) roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)] # [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')]
roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3) roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)] # [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')]
roberta.fill_mask('<mask> is the common currency of the European Union', topk=3) roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)] # [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')]
``` ```
#### Pronoun disambiguation (Winograd Schema Challenge): #### Pronoun disambiguation (Winograd Schema Challenge):
......
...@@ -174,11 +174,13 @@ class RobertaHubInterface(nn.Module): ...@@ -174,11 +174,13 @@ class RobertaHubInterface(nn.Module):
' {0}'.format(masked_token), predicted_token ' {0}'.format(masked_token), predicted_token
), ),
values[index].item(), values[index].item(),
predicted_token,
)) ))
else: else:
topk_filled_outputs.append(( topk_filled_outputs.append((
masked_input.replace(masked_token, predicted_token), masked_input.replace(masked_token, predicted_token),
values[index].item(), values[index].item(),
predicted_token,
)) ))
return topk_filled_outputs return topk_filled_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment