"src/turbomind/vscode:/vscode.git/clone" did not exist on "2e5285800ba0d7665c3f7480fdf5d87a29e3670c"
Commit 4fd2a16b authored by ngoyal2707's avatar ngoyal2707 Committed by Facebook Github Bot
Browse files

Bart push cnn eval

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/915

Differential Revision: D18580996

fbshipit-source-id: 9505a81892ba8ad997c03465d6a2d488c379c762
parent 9bf0f107
......@@ -12,6 +12,7 @@ Model | Description | # params | Download
---|---|---|---
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
## Results
......@@ -32,7 +33,7 @@ Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
`bart.large` | 88.8/94.6 | 86.1/89.2
**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(dev set, no additional data used)_
_(test set, no additional data used)_
Model | R1 | R2 | RL
---|---|---|---
......@@ -150,6 +151,51 @@ with open('glue_data/MNLI/dev_matched.tsv') as fin:
# Expected output: 0.9010
```
#### Evaluating the `bart.large.cnn` model:
Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample.
```python
bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('test.source') as source, open('test.hypo', 'w') as fout:
sline = source.readline().strip()
slines = [sline]
for sline in source:
if count % bsz == 0:
with torch.no_grad():
hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
for hypothesis in hypotheses_batch:
fout.write(hypothesis + '\n')
fout.flush()
slines = []
slines.append(sline.strip())
count += 1
if slines != []:
hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
for hypothesis in hypotheses_batch:
fout.write(hypothesis + '\n')
fout.flush()
```
Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
```bash
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
# Tokenize hypothesis and target files.
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
files2rouge test.hypo.tokenized test.hypo.target
# Expected output: (ROUGE-2 Average_F: 0.21238)
```
## Finetuning
- [Finetuning on GLUE](README.glue.md)
......
......@@ -3,11 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from fairseq import utils
from fairseq.data import encoders
......@@ -25,6 +29,11 @@ class BARTHubInterface(nn.Module):
self.bpe = encoders.build_bpe(args)
self.max_positions = min(utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
))
# this is useful for determining the device
self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
......@@ -52,7 +61,10 @@ class BARTHubInterface(nn.Module):
>>> bart.encode('world').tolist()
[0, 8331, 2]
"""
bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
tokens = self.bpe.encode(sentence)
if len(tokens.split(' ')) > self.max_positions - 2:
tokens = ' '.join(tokens.split(' ')[:self.max_positions - 2])
bpe_sentence = '<s> ' + tokens + ' </s>'
for s in addl_sentences:
bpe_sentence += (' </s>' if not no_separator else '')
bpe_sentence += ' ' + self.bpe.encode(s) + ' </s>'
......@@ -61,7 +73,7 @@ class BARTHubInterface(nn.Module):
def decode(self, tokens: torch.LongTensor):
assert tokens.dim() == 1
tokens = tokens.numpy()
tokens = tokens.cpu().numpy()
if tokens[0] == self.task.source_dictionary.bos():
tokens = tokens[1:] # remove <s>
eos_mask = (tokens == self.task.source_dictionary.eos())
......@@ -72,6 +84,52 @@ class BARTHubInterface(nn.Module):
return sentences[0]
return sentences
def _build_sample(self, src_tokens: List[torch.LongTensor]):
# assert torch.is_tensor(src_tokens)
dataset = self.task.build_dataset_for_inference(
src_tokens,
[x.numel() for x in src_tokens],
)
sample = dataset.collater(dataset)
sample = utils.apply_to_sample(
lambda tensor: tensor.to(self.device),
sample
)
return sample
def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> str:
input = [self.encode(sentence) for sentence in sentences]
hypos = self.generate(input, beam, verbose, **kwargs)
return [self.decode(x['tokens']) for x in hypos]
def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
sample = self._build_sample(tokens)
# build generator using current args as well as any kwargs
gen_args = copy.copy(self.args)
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator(gen_args)
translations = self.task.inference_step(
generator,
[self.model],
sample,
prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(self.task.source_dictionary.bos()),
)
if verbose:
src_str_with_unk = self.string(tokens)
print('S\t{}'.format(src_str_with_unk))
def getarg(name, default):
return getattr(gen_args, name, getattr(self.args, name, default))
# Process top predictions
hypos = [x[0] for x in translations]
hypos = [v for _, v in sorted(zip(sample['id'].tolist(), hypos))]
return hypos
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
......
......@@ -28,6 +28,7 @@ class BARTModel(TransformerModel):
return {
'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz',
}
def __init__(self, args, encoder, decoder):
......@@ -172,6 +173,13 @@ class BARTModel(TransformerModel):
for k in keys_to_delete:
del state_dict[k]
# When finetuning on translation task, remove last row of
# embedding matrix that corresponds to mask_idx token.
if self.args.task == 'translation':
dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
state_dict['encoder.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'][:dict_size-1, :]
state_dict['decoder.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'][:dict_size-1, :]
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, 'classification_heads'):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment