Commit 226c1f48 authored by ngoyal2707, committed by Facebook Github Bot

added instructions to FT bart on cnn-dm

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/922

Differential Revision: D18617322

fbshipit-source-id: 50645197cb7f075b5f878818a97358653077c3e0
parent 99fbd317
# Fine-tuning BART on CNN-Dailymail summarization task
### 1) Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into data files with non-tokenized, cased samples.
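The steps below assume the processed data ends up in a `cnn_dm/` directory with one untokenized, cased article per line in the `*.source` files and the matching reference summary on the same line of the `*.target` files. A sketch of the expected layout, inferred from the commands in steps 2-4 and the inference snippet rather than from the CNN-DM scripts themselves:
```bash
# Expected layout (file names inferred from steps 2-4 and the inference snippet)
ls cnn_dm/
# train.source  train.target  val.source  val.target  test.source  test.target
```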
### 2) BPE preprocess:
```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "cnn_dm/$SPLIT.$LANG" \
      --outputs "cnn_dm/$SPLIT.bpe.$LANG" \
      --workers 60 \
      --keep-empty;
  done
done
```
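Optionally (not part of the original instructions), verify that BPE encoding preserved the line alignment between articles and summaries:
```bash
# Optional sanity check: line counts should match before and after BPE encoding,
# and the source/target files of a split should have the same number of lines.
for SPLIT in train val
do
  wc -l cnn_dm/$SPLIT.source cnn_dm/$SPLIT.bpe.source \
        cnn_dm/$SPLIT.target cnn_dm/$SPLIT.bpe.target
done
```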
### 3) Binarize dataset:
```bash
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "cnn_dm/train.bpe" \
  --validpref "cnn_dm/val.bpe" \
  --destdir "cnn_dm-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
```
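If binarization succeeds, `cnn_dm-bin/` should contain the copied dictionaries plus indexed data for each split, roughly as sketched below (exact file names depend on the fairseq version and dataset implementation):
```bash
ls cnn_dm-bin/
# dict.source.txt  dict.target.txt  preprocess.log
# train.source-target.source.bin   train.source-target.source.idx
# train.source-target.target.bin   train.source-target.target.idx
# valid.source-target.source.bin   valid.source-target.source.idx
# valid.source-target.target.bin   valid.source-target.target.idx
```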
### 4) Fine-tuning on CNN-DM summarization task:
Example fine-tuning command:
```bash
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
```
The above command is expected to run on `1` node with `8` 32GB V100 GPUs.
Expected training time is about `5` hours. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
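`$BART_PATH` should point to the pretrained `bart.large` checkpoint (`model.pt`). If you do not have it yet, a sketch for downloading it (URL taken from the public BART release; adjust if the checkpoint has moved):
```bash
# Download and unpack the pretrained bart.large checkpoint;
# point BART_PATH at the extracted model.pt.
wget 'https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz'
tar -xzvf bart.large.tar.gz
BART_PATH=$(pwd)/bart.large/model.pt
```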
### Inference on CNN-DM test data using the above trained checkpoint
After training the model as described in the previous step, you can run inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:
```python
import torch
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='cnn_dm-bin'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1

    if slines != []:
        hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
```
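The generated `cnn_dm/test.hypo` can then be scored against the reference summaries. The main BART README computes ROUGE with `files2rouge`; a minimal sketch, assuming `files2rouge` is installed and that hypotheses and references have already been tokenized into `test.hypo.tokenized` and `test.hypo.target` as in that README:
```bash
# Assumes files2rouge is installed and both files are pre-tokenized;
# file names follow the evaluation step in the main BART README.
files2rouge test.hypo.tokenized test.hypo.target
```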
...
@@ -199,6 +199,7 @@ files2rouge test.hypo.tokenized test.hypo.target
 ## Finetuning
 - [Finetuning on GLUE](README.glue.md)
+- [Finetuning on CNN-DM](README.cnn.md)
 ## Citation
...
...
@@ -147,7 +147,6 @@ class DenoisingDataset(FairseqDataset):
             ps = torch.FloatTensor(ps)
             self.mask_span_distribution = torch.distributions.Categorical(ps)
         self.verbose = args.verbose
-        self.epoch = 0
     def set_epoch(self, epoch, **unused):
...
...
@@ -42,14 +42,6 @@ class BARTModel(TransformerModel):
     @staticmethod
     def add_args(parser):
         super(BARTModel, BARTModel).add_args(parser)
-        parser.add_argument(
-            '--max-source-positions', default=1024, type=int, metavar='N',
-            help='max number of tokens in the source sequence'
-        )
-        parser.add_argument(
-            '--max-target-positions', default=1024, type=int, metavar='N',
-            help='max number of tokens in the target sequence'
-        )
         parser.add_argument(
             '--pooler-dropout', type=float, metavar='D',
             help='dropout probability in the masked_lm pooler layers'
...
@@ -175,10 +167,10 @@ class BARTModel(TransformerModel):
         # When finetuning on translation task, remove last row of
         # embedding matrix that corresponds to mask_idx token.
-        if self.args.task == 'translation':
-            dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
-            state_dict['encoder.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'][:dict_size-1, :]
-            state_dict['decoder.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'][:dict_size-1, :]
+        loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
+        if loaded_dict_size == len(self.encoder.dictionary) + 1 and '<mask>' not in self.encoder.dictionary:
+            state_dict['encoder.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'][:loaded_dict_size-1, :]
+            state_dict['decoder.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :]
         # Copy any newly-added classification heads into the state dict
         # with their current weights.
...
...
@@ -68,13 +68,21 @@ class DenoisingTask(FairseqTask):
         )
         parser.add_argument(
             '--mask-length', default="subword", type=str,
-            choices=['subword', 'word', 'span-possion'],
+            choices=['subword', 'word', 'span-poisson'],
             help='mask length to choose'
         )
         parser.add_argument(
             '--replace-length', default=-1, type=int,
             help='when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)'
         )
+        parser.add_argument(
+            '--max-source-positions', default=1024, type=int, metavar='N',
+            help='max number of tokens in the source sequence'
+        )
+        parser.add_argument(
+            '--max-target-positions', default=1024, type=int, metavar='N',
+            help='max number of tokens in the target sequence'
+        )
     def __init__(self, args, dictionary):
         super().__init__(args)
...
@@ -94,7 +102,7 @@ class DenoisingTask(FairseqTask):
         args.shuffle_instance = False
         return cls(args, dictionary)
-    def load_dataset(self, split, epoch=0, combine=False):
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
...
...
@@ -8,11 +8,14 @@ import os
 from fairseq import options, utils
 from fairseq.data import (
+    AppendTokenDataset,
     ConcatDataset,
     data_utils,
     indexed_dataset,
     LanguagePairDataset,
     PrependTokenDataset,
+    StripTokenDataset,
+    TruncateDataset,
 )
 from . import FairseqTask, register_task
...
@@ -25,6 +28,7 @@ def load_langpair_dataset(
     combine, dataset_impl, upsample_primary,
     left_pad_source, left_pad_target, max_source_positions,
     max_target_positions, prepend_bos=False, load_alignments=False,
+    truncate_source=False,
 ):
     def split_exists(split, src, tgt, lang, data_path):
         filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
...
@@ -47,9 +51,16 @@
         else:
             raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
-        src_datasets.append(
-            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
-        )
+        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
+        if truncate_source:
+            src_dataset = AppendTokenDataset(
+                TruncateDataset(
+                    StripTokenDataset(src_dataset, src_dict.eos()),
+                    max_source_positions - 1,
+                ),
+                src_dict.eos(),
+            )
+        src_datasets.append(src_dataset)
         tgt_datasets.append(
             data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
         )
...
@@ -139,6 +150,8 @@ class TranslationTask(FairseqTask):
                             help='max number of tokens in the target sequence')
         parser.add_argument('--upsample-primary', default=1, type=int,
                             help='amount to upsample primary dataset')
+        parser.add_argument('--truncate-source', default=False, action='store_true',
+                            help='boolean to truncate source to max-source-positions')
         # fmt: on
     def __init__(self, args, src_dict, tgt_dict):
...
@@ -203,6 +216,7 @@ class TranslationTask(FairseqTask):
             max_source_positions=self.args.max_source_positions,
             max_target_positions=self.args.max_target_positions,
             load_alignments=self.args.load_alignments,
+            truncate_source=self.args.truncate_source,
         )
     def build_dataset_for_inference(self, src_tokens, src_lengths):
...