# Fine-tuning BART on CNN-Dailymail summarization task

### 1) Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into data files with non-tokenized, cased samples.

### 2) BPE preprocess:

```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
    --encoder-json encoder.json \
    --vocab-bpe vocab.bpe \
    --inputs "cnn_dm/$SPLIT.$LANG" \
    --outputs "cnn_dm/$SPLIT.bpe.$LANG" \
    --workers 60 \
    --keep-empty;
  done
done
```

### 3) Binarize dataset:

```bash
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "cnn_dm/train.bpe" \
  --validpref "cnn_dm/val.bpe" \
  --destdir "cnn_dm-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
```

### 4) Fine-tuning on CNN-DM summarization task:

Example fine-tuning command:

```bash
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
```

The above is expected to run on `1` node with `8 32GB V100` GPUs. Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.

### Inference for CNN-DM test data using the above trained checkpoint

After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:

```python
import torch
from fairseq.models.bart import BARTModel

# Load the fine-tuned checkpoint.
bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='cnn_dm-bin'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        # Generate summaries one batch of bsz source documents at a time.
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1

    # Generate summaries for the final partial batch.
    if slines != []:
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
```
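
As a quick sanity check before running the full batched loop, you can summarize a single document with the same `bart.sample` call. The sketch below is not part of the original recipe: the `article` string is a placeholder you would replace with a real source document, and the generation parameters simply mirror those used above.

```python
import torch
from fairseq.models.bart import BARTModel

# Load the same fine-tuned checkpoint used in the batched inference snippet above.
bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='cnn_dm-bin'
)
bart.cuda()
bart.eval()
bart.half()

# Placeholder input: any untokenized, cased source document.
article = "Replace this string with the text of a news article you want to summarize."

with torch.no_grad():
    # bart.sample takes a list of source strings and returns a list of summaries.
    summary = bart.sample([article], beam=4, lenpen=2.0, max_len_b=140,
                          min_len=55, no_repeat_ngram_size=3)[0]
print(summary)
```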