You need to sign in or sign up before continuing.
Commit 3e3fe722 authored by Roman Rädle's avatar Roman Rädle Committed by Facebook Github Bot
Browse files

Return predicted token for RoBERTa filling mask

Summary:
Added the `predicted_token` to each `topk` filled output item

Updated RoBERTa filling mask example in README.md

Reviewed By: myleott

Differential Revision: D17188810

fbshipit-source-id: 5fdc57ff2c13239dabf13a8dad43ae9a55e8931c
parent 1566cfb9
...@@ -167,13 +167,13 @@ RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the ...@@ -167,13 +167,13 @@ RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/): [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python ```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3) roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)] # [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')]
roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3) roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)] # [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')]
roberta.fill_mask('<mask> is the common currency of the European Union', topk=3) roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)] # [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')]
``` ```
#### Pronoun disambiguation (Winograd Schema Challenge): #### Pronoun disambiguation (Winograd Schema Challenge):
......
...@@ -174,11 +174,13 @@ class RobertaHubInterface(nn.Module): ...@@ -174,11 +174,13 @@ class RobertaHubInterface(nn.Module):
' {0}'.format(masked_token), predicted_token ' {0}'.format(masked_token), predicted_token
), ),
values[index].item(), values[index].item(),
predicted_token,
)) ))
else: else:
topk_filled_outputs.append(( topk_filled_outputs.append((
masked_input.replace(masked_token, predicted_token), masked_input.replace(masked_token, predicted_token),
values[index].item(), values[index].item(),
predicted_token,
)) ))
return topk_filled_outputs return topk_filled_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment