Commit 83249196 authored by Myle Ott, committed by Facebook Github Bot

Add WSC task and criterion

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1004

Differential Revision: D16751443

Pulled By: myleott

fbshipit-source-id: f70acd6c7be6d69da45b5b32fe4c4eff021539ab
parent a00ce132
@@ -12,7 +12,8 @@ Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)
## Results
@@ -24,12 +25,12 @@ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
##### Results on SuperGLUE tasks (dev set, single model, single-task finetuning)
Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
---|---|---|---|---|---|---|---
`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
`roberta.large.wsc` | - | - | - | - | - | - | 91.3
##### Results on SQuAD (dev set)
@@ -83,28 +84,6 @@ assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```
##### Use RoBERTa for sentence-pair classification tasks:
```python
# Download RoBERTa already finetuned for MNLI
@@ -141,22 +120,79 @@ roberta.cuda()
roberta.predict('new_task', tokens) # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```
## Advanced usage
#### Filling masks:
RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
>>> roberta.fill_mask("The first Star wars movie came out in <mask>", topk=3)
[('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
```
>>> roberta.fill_mask("Vikram samvat calender is official in <mask>", topk=3)
[('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
#### Pronoun disambiguation (Winograd Schema Challenge):
>>> roberta.fill_mask("<mask> is the common currency of the European Union", topk=3)
[('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
```bash
pip install spacy
python -m spacy download en_core_web_lg
```
Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
function. The pronoun should be surrounded by square brackets (`[]`) and the
query referent surrounded by underscores (`_`), or left blank to return the
predicted candidate text directly:
```python
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
roberta.cuda() # use the GPU (optional)
roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
# True
roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
# False
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
# 'The city councilmen'
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
# 'demonstrators'
```
See the [RoBERTa Winograd Schema Challenge (WSC) README](README.wsc.md) for more details on how to train this model.
#### Extract features aligned to words:
By default RoBERTa outputs one feature vector per BPE token. You can instead
realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
with the `extract_features_aligned_to_words` method. This will compute a
weighted average of the BPE-level features for each word and expose them in
spaCy's `Token.vector` attribute:
```python
doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
assert len(doc) == 10
for tok in doc:
print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
# <s> tensor([-0.1316, -0.0386, -0.0832, -0.0477, 0.1943], grad_fn=<SliceBackward>) (...)
# I tensor([ 0.0559, 0.1541, -0.4832, 0.0880, 0.0120], grad_fn=<SliceBackward>) (...)
# said tensor([-0.1565, -0.0069, -0.8915, 0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
# , tensor([-0.1318, -0.0387, -0.0834, -0.0477, 0.1944], grad_fn=<SliceBackward>) (...)
# " tensor([-0.0486, 0.1818, -0.3946, -0.0553, 0.0981], grad_fn=<SliceBackward>) (...)
# hello tensor([ 0.0079, 0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
# RoBERTa tensor([-0.2339, -0.1184, -0.7343, -0.0492, 0.5829], grad_fn=<SliceBackward>) (...)
# . tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# " tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# </s> tensor([-0.0930, -0.0392, -0.0821, 0.0158, 0.0649], grad_fn=<SliceBackward>) (...)
```
#### Evaluating the `roberta.large.mnli` model:
Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
@@ -181,6 +217,7 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))
- [Finetuning on GLUE](README.finetune_glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.finetune_custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](README.wsc.md)
- Finetuning on SQuAD: coming soon
## Pretraining using your own data
# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
The following instructions can be used to finetune RoBERTa on the WSC training
data provided by [SuperGLUE](https://super.gluebenchmark.com/).
Note that there is high variance in the results. For our GLUE/SuperGLUE
submission we swept over the learning rate, batch size and total number of
updates, as well as the random seed. Out of ~100 runs we chose the best 7 models
and ensembled them.
**Note:** The instructions below use a slightly different loss function than
the one described in the original RoBERTa arXiv paper. In particular,
[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
ranking loss between `(query, candidate)` pairs with tunable hyperparameters
alpha and beta. This is supported in our code as well via the
`--wsc-margin-alpha` and `--wsc-margin-beta` arguments. However, we achieved
slightly better (and more robust) results on the development set by instead
using a single cross-entropy loss term over the log-probabilities for the
query and all candidates. This reduces the number of hyperparameters; our best
model achieved 92.3% development set accuracy, compared to ~90% accuracy for
the margin loss. Later versions of the RoBERTa arXiv paper will describe this
updated formulation.
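For concreteness, here is a minimal sketch of the two formulations (the values
are illustrative; `query_lprobs` and `cand_lprobs` stand in for the
span-averaged log-probabilities computed by the WSC criterion further down):
```python
import torch
import torch.nn.functional as F

# span-averaged log-probabilities: the query (correct) span and three
# distractor candidate spans (illustrative values)
query_lprobs = torch.tensor([-0.9])
cand_lprobs = torch.tensor([-1.5, -2.1, -3.0])

# margin ranking loss (Kocijan et al., 2019) with tunable alpha/beta,
# exposed as --wsc-margin-alpha / --wsc-margin-beta
alpha, beta = 1.0, 0.4
margin_loss = (
    -query_lprobs
    + alpha * (cand_lprobs - query_lprobs + beta).clamp(min=0)
).sum()

# cross-entropy alternative (--wsc-cross-entropy): a single (C+1)-way
# classification in which index 0, the query span, is the correct class
logits = torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0)  # shape (1, C+1)
ce_loss = F.cross_entropy(logits, torch.zeros(1, dtype=torch.long))
```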
### 1) Download the WSC data from the SuperGLUE website:
```bash
wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
unzip WSC.zip
# we also need to copy the RoBERTa dictionary into the same directory
wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
```
### 2) Finetune over the provided training data:
```bash
TOTAL_NUM_UPDATES=2000 # Total number of training steps.
WARMUP_UPDATES=250 # Linearly increase LR over this many steps.
LR=2e-05 # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=16 # Batch size per GPU.
SEED=1 # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt
# we use the --user-dir option to load the task and criterion
# from the examples/roberta/wsc directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
cd fairseq
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend no_c10d \
--user-dir $FAIRSEQ_USER_DIR \
--task wsc --criterion wsc --wsc-cross-entropy \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--max-sentences $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100
```
The above command assumes training on 4 GPUs, but you can achieve the same
results on a single GPU by adding `--update-freq=4`.
### 3) Evaluate
```python
from fairseq.models.roberta import RobertaModel
from examples.roberta.wsc import wsc_utils # also loads WSC task and criterion
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
roberta.cuda()
nsamples, ncorrect = 0, 0
for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
pred = roberta.disambiguate_pronoun(sentence)
nsamples += 1
if pred == label:
ncorrect += 1
print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.9230769230769231
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import wsc_criterion # noqa
from . import wsc_task # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion('wsc')
class WSCCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
self.prediction_h = open(self.args.save_predictions, 'w')
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
parser.add_argument('--wsc-cross-entropy', action='store_true',
help='use cross entropy formulation instead of margin loss')
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
def forward(self, model, sample, reduce=True):
def get_masked_input(tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.task.mask
return masked_tokens
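        # score a candidate span: mask it out, then average the log-probabilities
        # that the model assigns to the original tokens at the masked positions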
def get_lprobs(tokens, mask):
logits, _ = model(src_tokens=get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
# compute loss and accuracy
loss, nloss = 0., 0
ncorrect, nqueries = 0, 0
for i, label in enumerate(sample['labels']):
query_lprobs = get_lprobs(
sample['query_tokens'][i].unsqueeze(0),
sample['query_masks'][i].unsqueeze(0),
)
cand_lprobs = get_lprobs(
sample['candidate_tokens'][i],
sample['candidate_masks'][i],
)
pred = (query_lprobs >= cand_lprobs).all().item()
if label is not None:
label = 1 if label else 0
ncorrect += 1 if pred == label else 0
nqueries += 1
if label:
# only compute a loss for positive instances
nloss += 1
if self.args.wsc_cross_entropy:
loss += F.cross_entropy(
torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
query_lprobs.new([0]).long(),
)
else:
loss += (
- query_lprobs
+ self.args.wsc_margin_alpha * (
cand_lprobs - query_lprobs + self.args.wsc_margin_beta
).clamp(min=0)
).sum()
id = sample['id'][i].item()
if self.prediction_h is not None:
print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
if nloss == 0:
loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
'ncorrect': ncorrect,
'nqueries': nqueries,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
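            # fairseq convention: average over the sample size and divide by
            # ln(2) so the loss is reported in base 2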
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
if nqueries > 0:
agg_output['accuracy'] = ncorrect / float(nqueries)
return agg_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
data_utils,
Dictionary,
encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
from . import wsc_utils
@register_task('wsc')
class WSCTask(FairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR',
help='path to data directory; we load <split>.jsonl')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
self.mask = vocab.add_symbol('<mask>')
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
# hack to handle GPT-2 BPE, which includes leading spaces
if args.bpe == 'gpt2':
self.leading_space = True
self.trailing_space = False
else:
self.leading_space = False
self.trailing_space = True
@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'wsc', 'Must set --criterion=wsc'
# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))
return cls(args, vocab)
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
def binarize(s: str, append_eos: bool = False):
if self.tokenizer is not None:
s = self.tokenizer.encode(s)
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=append_eos, add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens
if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))
query_tokens = []
query_masks = []
query_lengths = []
candidate_tokens = []
candidate_masks = []
candidate_lengths = []
labels = []
for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
prefix = sentence[:pronoun_span.start].text
suffix = sentence[pronoun_span.end:].text_with_ws
# spaCy spans include trailing spaces, but we need to know about
# leading spaces for the GPT-2 BPE
leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
# get noun phrases, excluding pronouns and anything overlapping with the query
cand_spans = wsc_utils.filter_noun_chunks(
wsc_utils.extended_noun_chunks(sentence),
exclude_pronouns=True,
exclude_query=query,
exact_match=False,
)
def binarize_with_mask(txt):
toks = binarize(
prefix + leading_space + txt + trailing_space + suffix,
append_eos=True,
)
mask = torch.zeros_like(toks, dtype=torch.uint8)
mask_start = len(binarize(prefix))
mask_size = len(binarize(leading_space + txt))
mask[mask_start:mask_start + mask_size] = 1
return toks, mask
if query is not None:
query_toks, query_mask = binarize_with_mask(query)
query_len = len(query_toks)
else:
query_toks, query_mask, query_len = None, None, 0
query_tokens.append(query_toks)
query_masks.append(query_mask)
query_lengths.append(query_len)
cand_toks, cand_masks = [], []
for cand_span in cand_spans:
toks, mask = binarize_with_mask(cand_span.text)
cand_toks.append(toks)
cand_masks.append(mask)
# collate candidates
cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
assert cand_toks.size() == cand_masks.size()
candidate_tokens.append(cand_toks)
candidate_masks.append(cand_masks)
candidate_lengths.append(cand_toks.size(1))
labels.append(label)
query_lengths = np.array(query_lengths)
query_tokens = ListDataset(query_tokens, query_lengths)
query_masks = ListDataset(query_masks, query_lengths)
candidate_lengths = np.array(candidate_lengths)
candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
candidate_masks = ListDataset(candidate_masks, candidate_lengths)
labels = ListDataset(labels, [1]*len(labels))
dataset = {
'id': IdDataset(),
'query_tokens': query_tokens,
'query_masks': query_masks,
'candidate_tokens': candidate_tokens,
'candidate_masks': candidate_masks,
'labels': labels,
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(query_tokens, reduce=True),
}
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[query_lengths],
)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(query_tokens))
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
if return_only:
return dataset
self.datasets[split] = dataset
return self.datasets[split]
def build_dataset_for_inference(self, sample_json):
with tempfile.NamedTemporaryFile(buffering=0) as h:
h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
dataset = self.load_dataset(
'disambiguate_pronoun',
data_path=h.name,
return_only=True,
)
return dataset
def disambiguate_pronoun(self, model, sentence, use_cuda=False):
sample_json = wsc_utils.convert_sentence_to_json(sentence)
dataset = self.build_dataset_for_inference(sample_json)
sample = dataset.collater([dataset[0]])
if use_cuda:
sample = utils.move_to_cuda(sample)
def get_masked_input(tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.mask
return masked_tokens
def get_lprobs(tokens, mask):
logits, _ = model(src_tokens=get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
cand_lprobs = get_lprobs(
sample['candidate_tokens'][0],
sample['candidate_masks'][0],
)
if sample['query_tokens'][0] is not None:
query_lprobs = get_lprobs(
sample['query_tokens'][0].unsqueeze(0),
sample['query_masks'][0].unsqueeze(0),
)
return (query_lprobs >= cand_lprobs).all().item() == 1
else:
best_idx = cand_lprobs.argmax().item()
full_cand = sample['candidate_tokens'][0][best_idx]
mask = sample['candidate_masks'][0][best_idx]
toks = full_cand[mask]
return self.bpe.decode(self.source_dictionary.string(toks)).strip()
@property
def source_dictionary(self):
return self.vocab
@property
def target_dictionary(self):
return self.vocab
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import json
def convert_sentence_to_json(sentence):
if '_' in sentence:
prefix, rest = sentence.split('_', 1)
query, rest = rest.split('_', 1)
query_index = len(prefix.rstrip().split(' '))
else:
query, query_index = None, None
prefix, rest = sentence.split('[', 1)
pronoun, rest = rest.split(']', 1)
pronoun_index = len(prefix.rstrip().split(' '))
sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
return {
'idx': 0,
'text': sentence,
'target': {
'span1_index': query_index,
'span1_text': query,
'span2_index': pronoun_index,
'span2_text': pronoun,
},
}
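# Illustrative example:
#   convert_sentence_to_json(
#       'The _trophy_ would not fit in the brown suitcase because [it] was too big.')
# yields text='The trophy would not fit in the brown suitcase because it was too big.'
# with span1_index=1, span1_text='trophy', span2_index=10, span2_text='it'
# (span indices are whitespace-token offsets into the cleaned text).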
def extended_noun_chunks(sentence):
noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
np_start, cur_np = 0, 'NONE'
for i, token in enumerate(sentence):
np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
if np_type != cur_np:
if cur_np != 'NONE':
noun_chunks.add((np_start, i))
if np_type != 'NONE':
np_start = i
cur_np = np_type
if cur_np != 'NONE':
noun_chunks.add((np_start, len(sentence)))
return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
def find_token(sentence, start_pos):
found_tok = None
for tok in sentence:
if tok.idx == start_pos:
found_tok = tok
break
return found_tok
def find_span(sentence, search_text, start=0):
search_text = search_text.lower()
for tok in sentence[start:]:
remainder = sentence[tok.i:].text.lower()
if remainder.startswith(search_text):
len_to_consume = len(search_text)
start_idx = tok.idx
for next_tok in sentence[tok.i:]:
end_idx = next_tok.idx + len(next_tok.text)
if end_idx - start_idx == len_to_consume:
span = sentence[tok.i:next_tok.i + 1]
return span
return None
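# Illustrative example: for a spaCy-parsed sentence "The trophy was too big.",
# find_span(sentence, 'too big') returns the Span covering "too big", or None
# if no token-aligned match exists.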
@lru_cache(maxsize=1)
def get_detokenizer():
from sacremoses import MosesDetokenizer
detok = MosesDetokenizer(lang='en')
return detok
@lru_cache(maxsize=1)
def get_spacy_nlp():
import en_core_web_lg
nlp = en_core_web_lg.load()
return nlp
def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
detok = get_detokenizer()
nlp = get_spacy_nlp()
with open(input_fname) as fin:
for line in fin:
sample = json.loads(line.strip())
if positive_only and 'label' in sample and not sample['label']:
# only consider examples where the query is correct
continue
target = sample['target']
# clean up the query
query = target['span1_text']
if query is not None:
if '\n' in query:
continue
if query.endswith('.') or query.endswith(','):
query = query[:-1]
# split tokens
tokens = sample['text'].split(' ')
def strip_pronoun(x):
return x.rstrip('.,"')
# find the pronoun
pronoun_idx = target['span2_index']
pronoun = strip_pronoun(target['span2_text'])
if strip_pronoun(tokens[pronoun_idx]) != pronoun:
# hack: sometimes the index is misaligned
if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
pronoun_idx += 1
else:
raise Exception('Misaligned pronoun!')
assert strip_pronoun(tokens[pronoun_idx]) == pronoun
# split tokens before and after the pronoun
before = tokens[:pronoun_idx]
after = tokens[pronoun_idx + 1:]
# the GPT BPE attaches leading spaces to tokens, so we keep track
# of whether we need spaces before or after the pronoun
leading_space = ' ' if pronoun_idx > 0 else ''
trailing_space = ' ' if len(after) > 0 else ''
# detokenize
before = detok.detokenize(before, return_str=True)
pronoun = detok.detokenize([pronoun], return_str=True)
after = detok.detokenize(after, return_str=True)
# hack: when the pronoun ends in a period (or comma), move the
# punctuation to the "after" part
if pronoun.endswith('.') or pronoun.endswith(','):
after = pronoun[-1] + trailing_space + after
pronoun = pronoun[:-1]
# hack: when the "after" part begins with a comma or period, remove
# the trailing space
if after.startswith('.') or after.startswith(','):
trailing_space = ''
# parse sentence with spacy
sentence = nlp(before + leading_space + pronoun + trailing_space + after)
# find pronoun span
start = len(before + leading_space)
first_pronoun_tok = find_token(sentence, start_pos=start)
pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
assert pronoun_span.text == pronoun
if eval:
# convert to format where pronoun is surrounded by "[]" and
# query is surrounded by "_"
query_span = find_span(sentence, query)
query_with_ws = '_{}_{}'.format(
query_span.text,
(' ' if query_span.text_with_ws.endswith(' ') else '')
)
pronoun_with_ws = '[{}]{}'.format(
pronoun_span.text,
(' ' if pronoun_span.text_with_ws.endswith(' ') else '')
)
if query_span.start < pronoun_span.start:
first = (query_span, query_with_ws)
second = (pronoun_span, pronoun_with_ws)
else:
first = (pronoun_span, pronoun_with_ws)
second = (query_span, query_with_ws)
sentence = (
sentence[:first[0].start].text_with_ws
+ first[1]
+ sentence[first[0].end:second[0].start].text_with_ws
+ second[1]
+ sentence[second[0].end:].text
)
yield sentence, sample.get('label', None)
else:
yield sentence, pronoun_span, query, sample.get('label', None)
def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
if exclude_pronouns:
chunks = [
np for np in chunks if (
np.lemma_ != '-PRON-'
and not all(tok.pos_ == 'PRON' for tok in np)
)
]
if exclude_query is not None:
excl_txt = [exclude_query.lower()]
filtered_chunks = []
for chunk in chunks:
lower_chunk = chunk.text.lower()
found = False
for excl in excl_txt:
if (
(not exact_match and (lower_chunk in excl or excl in lower_chunk))
or lower_chunk == excl
):
found = True
break
if not found:
filtered_chunks.append(chunk)
chunks = filtered_chunks
return chunks
@@ -345,6 +345,8 @@ def load_pretrained_component_from_model(
def verify_checkpoint_directory(save_dir: str) -> None:
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
temp_file_path = os.path.join(save_dir, 'dummy')
try:
with open(temp_file_path, 'w'):
@@ -16,6 +16,7 @@ from .concat_sentences_dataset import ConcatSentencesDataset
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
@@ -59,6 +60,7 @@ __all__ = [
'IndexedRawTextDataset',
'LanguagePairDataset',
'LeftPadDataset',
'ListDataset',
'LMContextWindowDataset',
'LRUCacheDataset',
'MaskTokensDataset',
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class ListDataset(BaseWrapperDataset):
def __init__(self, dataset, sizes):
super().__init__(dataset)
self._sizes = sizes
def collater(self, samples):
return samples
@property
def sizes(self):
return self._sizes
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
def set_epoch(self, epoch):
pass
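# ListDataset wraps a plain Python list (e.g., of variable-shape tensors or
# labels) with explicit per-item sizes so fairseq's batching utilities can
# query lengths; collater() deliberately returns the raw list of samples,
# leaving any padding/collation to downstream consumers.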
@@ -47,6 +47,9 @@ def from_pretrained(
if os.path.exists(path):
kwargs[arg] = path
if 'user_dir' in kwargs:
utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
arg_overrides=kwargs,
@@ -10,6 +10,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
@@ -152,11 +153,12 @@ class RobertaHubInterface(nn.Module):
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
with utils.eval(self.model):
features, extra = self.model(
tokens.long().to(device=self.device),
features_only=False,
return_all_hiddens=False,
)
logits = features[0, masked_index, :].squeeze()
prob = logits.softmax(dim=0)
values, index = prob.topk(k=topk, dim=0)
@@ -178,3 +180,18 @@ class RobertaHubInterface(nn.Module):
values[index].item(),
))
return topk_filled_outputs
def disambiguate_pronoun(self, sentence: str) -> bool:
"""
Usage::
>>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
True
>>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
'The trophy'
"""
assert hasattr(self.task, 'disambiguate_pronoun'), \
'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
with utils.eval(self.model):
return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda')
@@ -35,6 +35,7 @@ class RobertaModel(FairseqLanguageModel):
'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
}
def __init__(self, args, encoder):
@@ -14,8 +14,6 @@ import os
import re
import sys
from tqdm import tqdm
from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
@@ -208,6 +206,7 @@ class tqdm_progress_bar(progress_bar):
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
from tqdm import tqdm
self.tqdm = tqdm(iterable, self.prefix, leave=False)
def __iter__(self):
@@ -104,6 +104,7 @@ class MaskedLMTask(FairseqTask):
eos=self.source_dictionary.eos(),
break_mode=self.args.sample_break_mode,
)
print('| loaded {} batches from: {}'.format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
@@ -210,14 +211,3 @@ class MaskedLMTask(FairseqTask):
@property
def target_dictionary(self):
return self.dictionary
def get_average_masked_score(self, model, src_tokens, mask, **net_input):
"""Mask a set of tokens and return their average score."""
masked_tokens = src_tokens.clone()
masked_tokens[mask.byte()] = self.mask_idx
net_output = model(src_tokens=masked_tokens, **net_input, last_state_only=True)
lprobs = F.log_softmax(net_output[0], dim=-1, dtype=torch.float32)
lprobs = lprobs.gather(-1, src_tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(lprobs)
score = (lprobs * mask).sum(dim=-1) / mask.sum(dim=-1)
return score
@@ -12,14 +12,6 @@ from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
@contextlib.contextmanager
def eval(model):
is_training = model.training
model.eval()
yield
model.train(is_training)
@register_task('translation_moe')
class TranslationMoETask(TranslationTask):
"""
@@ -163,7 +155,7 @@ class TranslationMoETask(TranslationTask):
return lprob_yz
# compute responsibilities without dropout
with utils.eval(model): # disable dropout
with torch.no_grad(): # disable autograd
lprob_yz = get_lprob_yz() # B x K
prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import contextlib
import copy
import importlib.util
import math
@@ -277,6 +278,10 @@ def import_user_module(args):
module_path = getattr(args, 'user_dir', None)
if module_path is not None:
module_path = os.path.abspath(args.user_dir)
if not os.path.exists(module_path):
fairseq_rel_path = os.path.join(os.path.dirname(__file__), '..', args.user_dir)
if os.path.exists(fairseq_rel_path):
module_path = fairseq_rel_path
module_parent, module_name = os.path.split(module_path)
if module_name not in sys.modules:
@@ -339,3 +344,11 @@ def get_available_activation_fns() -> List:
'tanh',
'linear',
]
@contextlib.contextmanager
def eval(model):
is_training = model.training
model.eval()
yield
model.train(is_training)
@@ -21,7 +21,7 @@ for _model_type, _cls in MODEL_REGISTRY.items():
for model_name in _cls.hub_models().keys():
globals()[model_name] = functools.partial(
_cls.from_pretrained,
model_name,
)
# to simplify the interface we only expose named models
# globals()[_model_type] = _cls.from_pretrained