Commit 83249196 authored by Myle Ott, committed by Facebook Github Bot

Add WSC task and criterion

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1004

Differential Revision: D16751443

Pulled By: myleott

fbshipit-source-id: f70acd6c7be6d69da45b5b32fe4c4eff021539ab
parent a00ce132
...@@ -12,7 +12,8 @@ Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)
## Results

...@@ -24,12 +25,12 @@ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -

##### Results on SuperGLUE tasks (dev set, single model, single-task finetuning)

Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
---|---|---|---|---|---|---|---
`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
`roberta.large.wsc` | - | - | - | - | - | - | 91.3
##### Results on SQuAD (dev set)

...@@ -83,28 +84,6 @@ assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```
##### Use RoBERTa for sentence-pair classification tasks:
```python
# Download RoBERTa already finetuned for MNLI
...@@ -141,22 +120,79 @@ roberta.cuda()
roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```

## Advanced usage
#### Filling masks:
RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
```
>>> roberta.fill_mask("Vikram samvat calender is official in <mask>", topk=3) #### Pronoun disambiguation (Winograd Schema Challenge):
[('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
>>> roberta.fill_mask("<mask> is the common currency of the European Union", topk=3) RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
[('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)] ```bash
pip install spacy
python -m spacy download en_core_web_lg
```
Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
function. The pronoun should be surrounded by square brackets (`[]`) and the
query referent surrounded by underscores (`_`), or left blank to return the
predicted candidate text directly:
```python
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
roberta.cuda() # use the GPU (optional)
roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
# True
roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
# False
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
# 'The city councilmen'
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
# 'demonstrators'
```
See the [RoBERTa Winograd Schema Challenge (WSC) README](README.wsc.md) for more details on how to train this model.
#### Extract features aligned to words:
By default RoBERTa outputs one feature vector per BPE token. You can instead
realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
with the `extract_features_aligned_to_words` method. This will compute a
weighted average of the BPE-level features for each word and expose them in
spaCy's `Token.vector` attribute:
```python
doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
assert len(doc) == 10
for tok in doc:
print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
# <s> tensor([-0.1316, -0.0386, -0.0832, -0.0477, 0.1943], grad_fn=<SliceBackward>) (...)
# I tensor([ 0.0559, 0.1541, -0.4832, 0.0880, 0.0120], grad_fn=<SliceBackward>) (...)
# said tensor([-0.1565, -0.0069, -0.8915, 0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
# , tensor([-0.1318, -0.0387, -0.0834, -0.0477, 0.1944], grad_fn=<SliceBackward>) (...)
# " tensor([-0.0486, 0.1818, -0.3946, -0.0553, 0.0981], grad_fn=<SliceBackward>) (...)
# hello tensor([ 0.0079, 0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
# RoBERTa tensor([-0.2339, -0.1184, -0.7343, -0.0492, 0.5829], grad_fn=<SliceBackward>) (...)
# . tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# " tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# </s> tensor([-0.0930, -0.0392, -0.0821, 0.0158, 0.0649], grad_fn=<SliceBackward>) (...)
```
#### Evaluating the `roberta.large.mnli` model:

Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
...@@ -181,6 +217,7 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))
- [Finetuning on GLUE](README.finetune_glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.finetune_custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](README.wsc.md)
- Finetuning on SQuAD: coming soon

## Pretraining using your own data
...
# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
The following instructions can be used to finetune RoBERTa on the WSC training
data provided by [SuperGLUE](https://super.gluebenchmark.com/).
Note that there is high variance in the results. For our GLUE/SuperGLUE
submission we swept over the learning rate, batch size and total number of
updates, as well as the random seed. Out of ~100 runs we chose the best 7 models
and ensembled them.
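
As a rough illustration, one way to combine several finetuned checkpoints at evaluation time is a simple majority vote over their boolean predictions. This is only a sketch: the checkpoint filenames `run1.pt` ... `run7.pt` are hypothetical, and the exact combination scheme used for our submission may differ.
```python
from fairseq.models.roberta import RobertaModel
from examples.roberta.wsc import wsc_utils  # also loads the WSC task and criterion

# hypothetical checkpoint files; substitute your own best runs
checkpoints = ['run{}.pt'.format(i) for i in range(1, 8)]
models = [RobertaModel.from_pretrained('checkpoints', ckpt, 'WSC/') for ckpt in checkpoints]
for model in models:
    model.cuda()

nsamples, ncorrect = 0, 0
for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
    votes = [model.disambiguate_pronoun(sentence) for model in models]
    pred = sum(votes) > len(votes) / 2  # majority vote over True/False predictions
    nsamples += 1
    ncorrect += int(pred == label)
print('Ensemble accuracy: ' + str(ncorrect / float(nsamples)))
```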
**Note:** The instructions below use a slightly different loss function than the one described in the original RoBERTa arXiv paper. In particular, [Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin ranking loss between `(query, candidate)` pairs with tunable hyperparameters alpha and beta. This formulation is supported in our code as well, via the `--wsc-margin-alpha` and `--wsc-margin-beta` arguments. However, we achieved slightly better (and more robust) results on the development set by instead using a single cross entropy loss term over the log-probabilities for the query and all candidates. This reduces the number of hyperparameters, and our best model achieved 92.3% development set accuracy, compared to ~90% accuracy for the margin loss. Later versions of the RoBERTa arXiv paper will describe this updated formulation.
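
For reference, write `q` for the average log-probability that the masked LM assigns to the original tokens of the query span, and `c_k` for the same score for the k-th candidate span; this is how spans are scored in `wsc_criterion.py` below (the notation here is ours), and the loss is computed only for positive examples. The two objectives are then:
```latex
% margin ranking loss of Kocijan et al. (2019),
% with hyperparameters --wsc-margin-alpha (\alpha) and --wsc-margin-beta (\beta)
\mathcal{L}_{\mathrm{margin}} = \sum_{k} \Big( -q + \alpha \max\big(0,\; c_k - q + \beta\big) \Big)

% single cross entropy term over the query and all candidates (--wsc-cross-entropy)
\mathcal{L}_{\mathrm{CE}} = -\log \frac{e^{q}}{e^{q} + \sum_{k} e^{c_k}}
```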
### 1) Download the WSC data from the SuperGLUE website:
```bash
wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
unzip WSC.zip
# we also need to download the RoBERTa dictionary into the same directory
wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
```
### 2) Finetune over the provided training data:
```bash
TOTAL_NUM_UPDATES=2000 # Total number of training steps.
WARMUP_UPDATES=250 # Linearly increase LR over this many steps.
LR=2e-05 # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=16 # Batch size per GPU.
SEED=1 # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt
# we use the --user-dir option to load the task and criterion
# from the examples/roberta/wsc directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
cd fairseq
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend no_c10d \
--user-dir $FAIRSEQ_USER_DIR \
--task wsc --criterion wsc --wsc-cross-entropy \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--max-sentences $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100
```
The above command assumes training on 4 GPUs, but you can achieve the same
results on a single GPU by adding `--update-freq=4`.
### 3) Evaluate:
```python
from fairseq.models.roberta import RobertaModel
from examples.roberta.wsc import wsc_utils # also loads WSC task and criterion
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
roberta.cuda()
nsamples, ncorrect = 0, 0
for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
pred = roberta.disambiguate_pronoun(sentence)
nsamples += 1
if pred == label:
ncorrect += 1
print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.9230769230769231
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import wsc_criterion # noqa
from . import wsc_task # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion('wsc')
class WSCCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
self.prediction_h = open(self.args.save_predictions, 'w')
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
parser.add_argument('--wsc-cross-entropy', action='store_true',
help='use cross entropy formulation instead of margin loss')
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
def forward(self, model, sample, reduce=True):
def get_masked_input(tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.task.mask
return masked_tokens
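        # score a span by masking it out and averaging the masked LM's
        # log-probabilities of the original tokens over the masked positions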
def get_lprobs(tokens, mask):
logits, _ = model(src_tokens=get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
# compute loss and accuracy
loss, nloss = 0., 0
ncorrect, nqueries = 0, 0
for i, label in enumerate(sample['labels']):
query_lprobs = get_lprobs(
sample['query_tokens'][i].unsqueeze(0),
sample['query_masks'][i].unsqueeze(0),
)
cand_lprobs = get_lprobs(
sample['candidate_tokens'][i],
sample['candidate_masks'][i],
)
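            # the query is predicted correct iff it scores at least as high
            # as every candidate span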
pred = (query_lprobs >= cand_lprobs).all().item()
if label is not None:
label = 1 if label else 0
ncorrect += 1 if pred == label else 0
nqueries += 1
if label:
# only compute a loss for positive instances
nloss += 1
if self.args.wsc_cross_entropy:
loss += F.cross_entropy(
torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
query_lprobs.new([0]).long(),
)
else:
loss += (
- query_lprobs
+ self.args.wsc_margin_alpha * (
cand_lprobs - query_lprobs + self.args.wsc_margin_beta
).clamp(min=0)
).sum()
            sample_id = sample['id'][i].item()
            if self.prediction_h is not None:
                print('{}\t{}\t{}'.format(sample_id, pred, label), file=self.prediction_h)
if nloss == 0:
loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
'ncorrect': ncorrect,
'nqueries': nqueries,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
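            # summed loss normalized per query and converted to base 2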
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
if nqueries > 0:
agg_output['accuracy'] = ncorrect / float(nqueries)
return agg_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
data_utils,
Dictionary,
encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
from . import wsc_utils
@register_task('wsc')
class WSCTask(FairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR',
help='path to data directory; we load <split>.jsonl')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
self.mask = vocab.add_symbol('<mask>')
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
# hack to handle GPT-2 BPE, which includes leading spaces
if args.bpe == 'gpt2':
self.leading_space = True
self.trailing_space = False
else:
self.leading_space = False
self.trailing_space = True
@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'wsc', 'Must set --criterion=wsc'
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))
return cls(args, vocab)
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
def binarize(s: str, append_eos: bool = False):
if self.tokenizer is not None:
s = self.tokenizer.encode(s)
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=append_eos, add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens
if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
query_tokens = []
query_masks = []
query_lengths = []
candidate_tokens = []
candidate_masks = []
candidate_lengths = []
labels = []
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
prefix = sentence[:pronoun_span.start].text
suffix = sentence[pronoun_span.end:].text_with_ws
# spaCy spans include trailing spaces, but we need to know about
# leading spaces for the GPT-2 BPE
leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
# get noun phrases, excluding pronouns and anything overlapping with the query
cand_spans = wsc_utils.filter_noun_chunks(
wsc_utils.extended_noun_chunks(sentence),
exclude_pronouns=True,
exclude_query=query,
exact_match=False,
)
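            # binarize the sentence with `txt` substituted for the pronoun,
            # returning the tokens and a mask marking the substituted span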
def binarize_with_mask(txt):
toks = binarize(
prefix + leading_space + txt + trailing_space + suffix,
append_eos=True,
)
mask = torch.zeros_like(toks, dtype=torch.uint8)
mask_start = len(binarize(prefix))
mask_size = len(binarize(leading_space + txt))
mask[mask_start:mask_start + mask_size] = 1
return toks, mask
if query is not None:
query_toks, query_mask = binarize_with_mask(query)
query_len = len(query_toks)
else:
query_toks, query_mask, query_len = None, None, 0
query_tokens.append(query_toks)
query_masks.append(query_mask)
query_lengths.append(query_len)
cand_toks, cand_masks = [], []
for cand_span in cand_spans:
toks, mask = binarize_with_mask(cand_span.text)
cand_toks.append(toks)
cand_masks.append(mask)
# collate candidates
cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
assert cand_toks.size() == cand_masks.size()
candidate_tokens.append(cand_toks)
candidate_masks.append(cand_masks)
candidate_lengths.append(cand_toks.size(1))
labels.append(label)
query_lengths = np.array(query_lengths)
query_tokens = ListDataset(query_tokens, query_lengths)
query_masks = ListDataset(query_masks, query_lengths)
candidate_lengths = np.array(candidate_lengths)
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
labels = ListDataset(labels, [1]*len(labels))
dataset = {
'id': IdDataset(),
'query_tokens': query_tokens,
'query_masks': query_masks,
'candidate_tokens': candidate_tokens,
'candidate_masks': candidate_masks,
'labels': labels,
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[query_lengths],
)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(query_tokens))
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
if return_only:
return dataset
self.datasets[split] = dataset
return self.datasets[split]
def build_dataset_for_inference(self, sample_json):
with tempfile.NamedTemporaryFile(buffering=0) as h:
h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
dataset = self.load_dataset(
'disambiguate_pronoun',
data_path=h.name,
return_only=True,
)
return dataset
def disambiguate_pronoun(self, model, sentence, use_cuda=False):
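        # convert the marked-up sentence to a WSC-style JSON example, then score
        # the query span against candidate spans (or return the best candidate)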
sample_json = wsc_utils.convert_sentence_to_json(sentence)
dataset = self.build_dataset_for_inference(sample_json)
sample = dataset.collater([dataset[0]])
if use_cuda:
sample = utils.move_to_cuda(sample)
def get_masked_input(tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.mask
return masked_tokens
def get_lprobs(tokens, mask):
logits, _ = model(src_tokens=get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
cand_lprobs = get_lprobs(
sample['candidate_tokens'][0],
sample['candidate_masks'][0],
)
if sample['query_tokens'][0] is not None:
query_lprobs = get_lprobs(
sample['query_tokens'][0].unsqueeze(0),
sample['query_masks'][0].unsqueeze(0),
)
return (query_lprobs >= cand_lprobs).all().item() == 1
else:
best_idx = cand_lprobs.argmax().item()
full_cand = sample['candidate_tokens'][0][best_idx]
mask = sample['candidate_masks'][0][best_idx]
toks = full_cand[mask]
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@property
def source_dictionary(self):
return self.vocab
@property
def target_dictionary(self):
return self.vocab
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import json
def convert_sentence_to_json(sentence):
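    # the query referent is marked with underscores (e.g. "_trophy_") and the
    # pronoun with square brackets (e.g. "[it]"); see README.wsc.md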
if '_' in sentence:
prefix, rest = sentence.split('_', 1)
query, rest = rest.split('_', 1)
query_index = len(prefix.rstrip().split(' '))
else:
query, query_index = None, None
prefix, rest = sentence.split('[', 1)
pronoun, rest = rest.split(']', 1)
pronoun_index = len(prefix.rstrip().split(' '))
sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
return {
'idx': 0,
'text': sentence,
'target': {
'span1_index': query_index,
'span1_text': query,
'span2_index': pronoun_index,
'span2_text': pronoun,
},
}
def extended_noun_chunks(sentence):
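    # augment spaCy's noun chunks with maximal runs of consecutive
    # NOUN/PROPN tokens, widening the pool of candidate spans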
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
np_start, cur_np = 0, 'NONE'
for i, token in enumerate(sentence):
np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
if np_type != cur_np:
if cur_np != 'NONE':
noun_chunks.add((np_start, i))
if np_type != 'NONE':
np_start = i
cur_np = np_type
if cur_np != 'NONE':
noun_chunks.add((np_start, len(sentence)))
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
def find_token(sentence, start_pos):
found_tok = None
for tok in sentence:
if tok.idx == start_pos:
found_tok = tok
break
return found_tok
def find_span(sentence, search_text, start=0):
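    # return the token span of `sentence` whose text matches `search_text`,
    # searching from token index `start`; returns None if no match is found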
search_text = search_text.lower()
for tok in sentence[start:]:
remainder = sentence[tok.i:].text.lower()
if remainder.startswith(search_text):
len_to_consume = len(search_text)
start_idx = tok.idx
for next_tok in sentence[tok.i:]:
end_idx = next_tok.idx + len(next_tok.text)
if end_idx - start_idx == len_to_consume:
span = sentence[tok.i:next_tok.i + 1]
return span
return None
@lru_cache(maxsize=1)
def get_detokenizer():
from sacremoses import MosesDetokenizer
detok = MosesDetokenizer(lang='en')
return detok
@lru_cache(maxsize=1)
def get_spacy_nlp():
import en_core_web_lg
nlp = en_core_web_lg.load()
return nlp
def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
detok = get_detokenizer()
nlp = get_spacy_nlp()
with open(input_fname) as fin:
for line in fin:
sample = json.loads(line.strip())
if positive_only and 'label' in sample and not sample['label']:
# only consider examples where the query is correct
continue
target = sample['target']
# clean up the query
query = target['span1_text']
if query is not None:
if '\n' in query:
continue
if query.endswith('.') or query.endswith(','):
query = query[:-1]
# split tokens
tokens = sample['text'].split(' ')
def strip_pronoun(x):
return x.rstrip('.,"')
# find the pronoun
pronoun_idx = target['span2_index']
pronoun = strip_pronoun(target['span2_text'])
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
# hack: sometimes the index is misaligned
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
pronoun_idx += 1
else:
raise Exception('Misaligned pronoun!')
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
# split tokens before and after the pronoun
before = tokens[:pronoun_idx]
after = tokens[pronoun_idx + 1:]
# the GPT BPE attaches leading spaces to tokens, so we keep track
# of whether we need spaces before or after the pronoun
leading_space = ' ' if pronoun_idx > 0 else ''
trailing_space = ' ' if len(after) > 0 else ''
# detokenize
before = detok.detokenize(before, return_str=True)
pronoun = detok.detokenize([pronoun], return_str=True)
after = detok.detokenize(after, return_str=True)
# hack: when the pronoun ends in a period (or comma), move the
# punctuation to the "after" part
if pronoun.endswith('.') or pronoun.endswith(','):
after = pronoun[-1] + trailing_space + after
pronoun = pronoun[:-1]
# hack: when the "after" part begins with a comma or period, remove
# the trailing space
if after.startswith('.') or after.startswith(','):
trailing_space = ''
# parse sentence with spacy
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
# find pronoun span
start = len(before + leading_space)
first_pronoun_tok = find_token(sentence, start_pos=start)
pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
assert pronoun_span.text == pronoun
if eval:
# convert to format where pronoun is surrounded by "[]" and
# query is surrounded by "_"
query_span = find_span(sentence, query)
query_with_ws = '_{}_{}'.format(
query_span.text,
(' ' if query_span.text_with_ws.endswith(' ') else '')
)
pronoun_with_ws = '[{}]{}'.format(
pronoun_span.text,
(' ' if pronoun_span.text_with_ws.endswith(' ') else '')
)
if query_span.start < pronoun_span.start:
first = (query_span, query_with_ws)
second = (pronoun_span, pronoun_with_ws)
else:
first = (pronoun_span, pronoun_with_ws)
second = (query_span, query_with_ws)
sentence = (
sentence[:first[0].start].text_with_ws
+ first[1]
+ sentence[first[0].end:second[0].start].text_with_ws
+ second[1]
+ sentence[second[0].end:].text
)
yield sentence, sample.get('label', None)
else:
yield sentence, pronoun_span, query, sample.get('label', None)
def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
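    # optionally drop pronoun chunks and chunks (partially) matching the query
    # text, so the query itself never appears among the candidate spans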
if exclude_pronouns:
chunks = [
np for np in chunks if (
np.lemma_ != '-PRON-'
and not all(tok.pos_ == 'PRON' for tok in np)
)
]
if exclude_query is not None:
excl_txt = [exclude_query.lower()]
filtered_chunks = []
for chunk in chunks:
lower_chunk = chunk.text.lower()
found = False
for excl in excl_txt:
if (
(not exact_match and (lower_chunk in excl or excl in lower_chunk))
or lower_chunk == excl
):
found = True
break
if not found:
filtered_chunks.append(chunk)
chunks = filtered_chunks
return chunks
...@@ -345,6 +345,8 @@ def load_pretrained_component_from_model(
def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, 'dummy')
    try:
        with open(temp_file_path, 'w'):
...
...@@ -16,6 +16,7 @@ from .concat_sentences_dataset import ConcatSentencesDataset
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
...@@ -59,6 +60,7 @@ __all__ = [
    'IndexedRawTextDataset',
    'LanguagePairDataset',
    'LeftPadDataset',
    'ListDataset',
    'LMContextWindowDataset',
    'LRUCacheDataset',
    'MaskTokensDataset',
...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class ListDataset(BaseWrapperDataset):
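    """Wrap a plain list of items, returning them uncollated and reporting
    externally provided sizes (used, e.g., for the variable-shape candidate
    tensors in the WSC task)."""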
def __init__(self, dataset, sizes):
super().__init__(dataset)
self._sizes = sizes
def collater(self, samples):
return samples
@property
def sizes(self):
return self._sizes
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
def set_epoch(self, epoch):
pass
...@@ -47,6 +47,9 @@ def from_pretrained(
        if os.path.exists(path):
            kwargs[arg] = path

    if 'user_dir' in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
        arg_overrides=kwargs,
...
...@@ -10,6 +10,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import encoders
...@@ -152,6 +153,7 @@ class RobertaHubInterface(nn.Module):
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        with utils.eval(self.model):
            features, extra = self.model(
                tokens.long().to(device=self.device),
                features_only=False,
...@@ -178,3 +180,18 @@ class RobertaHubInterface(nn.Module):
                values[index].item(),
            ))
        return topk_filled_outputs
def disambiguate_pronoun(self, sentence: str) -> bool:
"""
Usage::
>>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
True
>>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
'The trophy'
"""
assert hasattr(self.task, 'disambiguate_pronoun'), \
'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
with utils.eval(self.model):
return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda')
...@@ -35,6 +35,7 @@ class RobertaModel(FairseqLanguageModel):
            'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
            'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
            'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
            'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
        }

    def __init__(self, args, encoder):
...
...@@ -14,8 +14,6 @@ import os
import re
import sys

from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
...@@ -208,6 +206,7 @@ class tqdm_progress_bar(progress_bar):
    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        from tqdm import tqdm
        self.tqdm = tqdm(iterable, self.prefix, leave=False)

    def __iter__(self):
...
...@@ -104,6 +104,7 @@ class MaskedLMTask(FairseqTask):
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        print('| loaded {} batches from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
...@@ -210,14 +211,3 @@ class MaskedLMTask(FairseqTask):
    @property
    def target_dictionary(self):
        return self.dictionary
...@@ -12,14 +12,6 @@ from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask


@register_task('translation_moe')
class TranslationMoETask(TranslationTask):
    """
...@@ -163,7 +155,7 @@ class TranslationMoETask(TranslationTask):
            return lprob_yz

        # compute responsibilities without dropout
        with utils.eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K

        prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
...
...@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import contextlib
import copy
import importlib.util
import math
...@@ -277,6 +278,10 @@ def import_user_module(args):
    module_path = getattr(args, 'user_dir', None)
    if module_path is not None:
        module_path = os.path.abspath(args.user_dir)
        if not os.path.exists(module_path):
            fairseq_rel_path = os.path.join(os.path.dirname(__file__), '..', args.user_dir)
            if os.path.exists(fairseq_rel_path):
                module_path = fairseq_rel_path

        module_parent, module_name = os.path.split(module_path)
        if module_name not in sys.modules:
...@@ -339,3 +344,11 @@ def get_available_activation_fns() -> List:
        'tanh',
        'linear',
    ]
@contextlib.contextmanager
def eval(model):
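    """Context manager that temporarily puts `model` in eval mode (e.g. to
    disable dropout), restoring the original training mode on exit."""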
is_training = model.training
model.eval()
yield
model.train(is_training)
...@@ -21,7 +21,7 @@ for _model_type, _cls in MODEL_REGISTRY.items():
    for model_name in _cls.hub_models().keys():
        globals()[model_name] = functools.partial(
            _cls.from_pretrained,
            model_name,
        )

# to simplify the interface we only expose named models
# globals()[_model_type] = _cls.from_pretrained