Commit 4fd2a16b authored by ngoyal2707, committed by Facebook Github Bot

Bart push cnn eval

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/915

Differential Revision: D18580996

fbshipit-source-id: 9505a81892ba8ad997c03465d6a2d488c379c762
parent 9bf0f107
...@@ -12,6 +12,7 @@ Model | Description | # params | Download
---|---|---|---
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
## Results
...@@ -32,7 +33,7 @@ Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
`bart.large` | 88.8/94.6 | 86.1/89.2
**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(test set, no additional data used)_
Model | R1 | R2 | RL
---|---|---|---
...@@ -150,6 +151,51 @@ with open('glue_data/MNLI/dev_matched.tsv') as fin:
# Expected output: 0.9010
```
#### Evaluating the `bart.large.cnn` model:
Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` each have one line per non-tokenized sample.
```python
import torch

bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('test.source') as source, open('test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            # Summarize the current batch of source lines.
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []
        slines.append(sline.strip())
        count += 1
    # Summarize the final, possibly smaller, batch.
    if slines != []:
        hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
```
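Before committing to the full test set, it can be worth smoke-testing the model on a single document. This is a minimal sketch continuing from the setup above; the input string is only a placeholder, and `bart.sample` with these generation settings is the same API used in the loop:

```python
# Placeholder article text; substitute any untokenized document.
doc = "The tower is 324 metres tall, about the same height as an 81-storey building."
with torch.no_grad():
    hypo = bart.sample([doc], beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
print(hypo[0])  # bart.sample returns one decoded summary per input string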
Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
```bash
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
# Tokenize hypothesis and target files.
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
files2rouge test.hypo.tokenized test.hypo.target
# Expected output: (ROUGE-2 Average_F: 0.21238)
```
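The CoreNLP + `files2rouge` pipeline above is the protocol behind the reported numbers. As a rough, dependency-light cross-check (not a replacement: tokenization differs, so scores will not match exactly), something like Google's `rouge_score` package could be used; a sketch under that assumption:

```python
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
with open('test.hypo') as hypos, open('test.target') as targets:
    # scorer.score(reference, prediction) returns a dict of Score tuples.
    scores = [scorer.score(ref.strip(), hyp.strip())['rouge2'].fmeasure
              for hyp, ref in zip(hypos, targets)]
print(sum(scores) / len(scores))  # mean ROUGE-2 F1 over the test set
```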
## Finetuning
- [Finetuning on GLUE](README.glue.md)
...
...@@ -3,11 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from fairseq import utils
from fairseq.data import encoders
...@@ -25,6 +29,11 @@ class BARTHubInterface(nn.Module):
        self.bpe = encoders.build_bpe(args)
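        # effective maximum source length: the tighter of the task's and the model's position limits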
        self.max_positions = min(utils.resolve_max_positions(
            self.task.max_positions(),
            self.model.max_positions(),
        ))
        # this is useful for determining the device
        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
...@@ -52,7 +61,10 @@ class BARTHubInterface(nn.Module):
            >>> bart.encode('world').tolist()
            [0, 8331, 2]
        """
        tokens = self.bpe.encode(sentence)
        if len(tokens.split(' ')) > self.max_positions - 2:
            tokens = ' '.join(tokens.split(' ')[:self.max_positions - 2])
        bpe_sentence = '<s> ' + tokens + ' </s>'
        for s in addl_sentences:
            bpe_sentence += (' </s>' if not no_separator else '')
            bpe_sentence += ' ' + self.bpe.encode(s) + ' </s>'
...@@ -61,7 +73,7 @@ class BARTHubInterface(nn.Module):
    def decode(self, tokens: torch.LongTensor):
        assert tokens.dim() == 1
        tokens = tokens.cpu().numpy()
        if tokens[0] == self.task.source_dictionary.bos():
            tokens = tokens[1:]  # remove <s>
        eos_mask = (tokens == self.task.source_dictionary.eos())
...@@ -72,6 +84,52 @@ class BARTHubInterface(nn.Module):
            return sentences[0]
        return sentences

    def _build_sample(self, src_tokens: List[torch.LongTensor]):
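        # collate the source-token tensors into a padded mini-batch and move it onto the model's device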
        # assert torch.is_tensor(src_tokens)
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
        )
        sample = dataset.collater(dataset)
        sample = utils.apply_to_sample(
            lambda tensor: tensor.to(self.device),
            sample
        )
        return sample

    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
        input = [self.encode(sentence) for sentence in sentences]
        hypos = self.generate(input, beam, verbose, **kwargs)
        return [self.decode(x['tokens']) for x in hypos]

    def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
        sample = self._build_sample(tokens)

        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
        generator = self.task.build_generator(gen_args)
        translations = self.task.inference_step(
            generator,
            [self.model],
            sample,
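            # force a <s> (BOS) prefix on every hypothesis in the batch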
            prefix_tokens=sample['net_input']['src_tokens'].new_zeros((len(tokens), 1)).fill_(self.task.source_dictionary.bos()),
        )

        if verbose:
            src_str_with_unk = self.string(tokens)
            print('S\t{}'.format(src_str_with_unk))

        def getarg(name, default):
            return getattr(gen_args, name, getattr(self.args, name, default))

        # Process top predictions
        hypos = [x[0] for x in translations]
        hypos = [v for _, v in sorted(zip(sample['id'].tolist(), hypos))]
        return hypos

    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
...
...@@ -28,6 +28,7 @@ class BARTModel(TransformerModel):
        return {
            'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
            'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
            'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz',
        }

    def __init__(self, args, encoder, decoder):
...@@ -172,6 +173,13 @@ class BARTModel(TransformerModel):
        for k in keys_to_delete:
            del state_dict[k]

        # When finetuning on a translation task, remove the last row of the
        # embedding matrix, which corresponds to the mask_idx token.
        if self.args.task == 'translation':
            dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
            state_dict['encoder.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'][:dict_size-1, :]
            state_dict['decoder.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'][:dict_size-1, :]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
...