Unverified Commit b54ef78d authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Bart-CNN (#3059)

`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
parent 6b1ff250
......@@ -4,20 +4,27 @@ Bart
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer on 29 Oct, 2019.
It is a sequence to sequence model where both encoder and decoder are transformers. The paper also introduces a novel pretraining objective, and demonstrates excellent summarization results.
The authors released their code `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_
Paper
~~~~~
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
According to the abstract:
**Abstract:**
- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT).
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token.
- BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE.
*We present BART, a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text. It uses a standard Tranformer-based neural machine translation architecture which, despite its simplicity, can be seen as generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder), and many other more recent pretraining schemes. We evaluate a number of noising approaches, finding the best performance by both randomly shuffling the order of the original sentences and using a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides a 1.1 BLEU increase over a back-translation system for machine translation, with only target language pretraining. We also report ablation experiments that replicate other pretraining schemes within the BART framework, to better measure which factors most influence end-task performance.*
`BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension`
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_
Notes:
Implementation Notes
~~~~~~~~~~~~~~~~~~~~
- Bart doesn't use :obj:`token_type_ids`, for sequence classification just use BartTokenizer.encode to get the proper splitting.
- Inputs to the decoder are created by BartModel.forward if they are not passed. This is different than some other model APIs.
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to fairseq.encode starts with a space.
- Decoder inputs are created automatically by the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``
BartModel
- ``MaskedLM.generate`` should be used for summarization, see the example in that docstrings
BartModel
~~~~~~~~~~~~~~~~~~~~
......@@ -30,7 +37,7 @@ BartForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForMaskedLM
:members: forward
:members: forward, generate
BartForSequenceClassification
......
......@@ -280,7 +280,10 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
| | | | bart-large base architecture with a classification head |
| | | | bart-large base architecture with a classification head, finetuned on MNLI |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
| | | | bart-large base architecture finetuned on cnn summarization task |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
......
......@@ -26,7 +26,7 @@ _bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": _bart_large_url,
"bart-large-mnli": _bart_large_url, # fine as same
"bart-cnn": None, # not done
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
}
......@@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
classifier_dropout=0.0,
output_past=False,
num_labels=3,
bos_token_id=0,
**common_kwargs
):
r"""
......@@ -67,12 +68,16 @@ class BartConfig(PretrainedConfig):
config = BartConfig.from_pretrained('bart-large')
model = BartModel(config)
"""
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
super().__init__(
num_labels=num_labels,
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
**common_kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
self.eos_token_id = eos_token_id
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = self.num_hidden_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
......
......@@ -23,9 +23,11 @@ import fairseq
import torch
from packaging import version
from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer
from transformers import BartConfig, BartForMaskedLM, BartForSequenceClassification, BartModel, BartTokenizer
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
......@@ -33,7 +35,7 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = "Hello world! cécé herlolip"
SAMPLE_TEXT = " Hello world! cécé herlolip"
rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
......@@ -41,7 +43,7 @@ rename_keys = [
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
def rename_key(dct, old, new):
......@@ -53,36 +55,45 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
b2.eval() # disable dropout
b2.model.upgrade_state_dict(b2.model.state_dict())
config = BartConfig()
tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0)
bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
bart.eval() # disable dropout
bart.model.upgrade_state_dict(bart.model.state_dict())
hf_model_name = checkpoint_path.replace(".", "-")
config = BartConfig.from_pretrained(hf_model_name)
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
assert torch.eq(tokens, tokens2).all()
# assert their_output.size() == (1, 11, 1024)
if checkpoint_path == "bart.large":
state_dict = b2.model.state_dict()
if checkpoint_path in ["bart.large", "bart.large.cnn"]:
state_dict = bart.model.state_dict()
for k in IGNORE_KEYS:
state_dict.pop(k, None)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = BartModel(config)
their_output = b2.extract_features(tokens)
their_output = bart.extract_features(tokens)
else: # MNLI Case
state_dict = b2.state_dict()
state_dict = bart.state_dict()
for k in IGNORE_KEYS:
state_dict.pop(k, None)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
state_dict.pop("_float_tensor", None)
model = BartForSequenceClassification(config)
their_output = b2.predict("mnli", tokens, return_logits=True)
for k in IGNORE_KEYS:
state_dict.pop(k, None)
their_output = bart.predict("mnli", tokens, return_logits=True)
# Load state dict
model.load_state_dict(state_dict)
model.eval()
our_outputs = model.forward(tokens)[0]
# Check results
if checkpoint_path == "bart.large.cnn": # generate doesnt work yet
model = BartForMaskedLM(config, base_model=model)
assert "lm_head.weight" in model.state_dict()
assert model.lm_head.out_features == config.max_position_embeddings
model.eval()
our_outputs = model.model.forward(tokens)[0]
else:
our_outputs = model.forward(tokens)[0]
assert their_output.shape == our_outputs.shape
assert (their_output == our_outputs).all().item()
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
......@@ -92,7 +103,8 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="")
parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_bart_checkpoint(
......
......@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
from typing import Dict, List, Optional, Tuple
......@@ -24,7 +24,7 @@ from torch import Tensor, nn
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
logger = logging.getLogger(__name__)
......@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
}
BART_START_DOCSTRING = r"""
......@@ -332,7 +333,7 @@ class DecoderLayer(nn.Module):
x,
encoder_hidden_states,
encoder_attn_mask=None,
decoder_cached_states=None,
layer_state=None,
attention_mask=None,
need_attn_weights=False,
):
......@@ -348,43 +349,28 @@ class DecoderLayer(nn.Module):
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if decoder_cached_states is None:
prev_self_attn_state, prev_attn_state = (None, None)
else:
assert len(decoder_cached_states) == 3
prev_self_attn_state, prev_attn_state = (
decoder_cached_states["self"],
decoder_cached_states["encoder_decoder"],
)
residual = x
if prev_self_attn_state is not None:
saved_state = prev_self_attn_state
decoder_cached_states["self"] = saved_state
y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it
if layer_state is None:
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn.forward(
query=x,
key=y,
value=y,
decoder_cached_states=decoder_cached_states,
need_weights=need_attn_weights,
attn_mask=attention_mask,
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if prev_attn_state is not None:
saved_state = prev_attn_state
decoder_cached_states["encoder_decoder"] = saved_state
x, encoder_attn_weights = self.encoder_attn.forward(
query=x,
key=encoder_hidden_states, # could be None
value=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
decoder_cached_states=decoder_cached_states,
layer_state=layer_state, # mutates layer state
static_kv=True,
need_weights=False, # not returning it so why compute it
)
......@@ -403,15 +389,8 @@ class DecoderLayer(nn.Module):
return (
x,
self_attn_weights,
decoder_cached_states,
) # just self_attn weights for now, following t5, decoder_cached_states = cache for decoding
def _past_to_dict(self, prev_attn_state):
prev_key, prev_value = prev_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
return saved_state
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
class BartDecoder(nn.Module):
......@@ -440,6 +419,7 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.generation_mode = False
def forward(
self,
......@@ -469,10 +449,14 @@ class BartDecoder(nn.Module):
- attentions
"""
# embed positions
positions = self.embed_positions(input_ids)
x = self.embed_tokens(input_ids)
positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
if positions is not None:
if self.generation_mode:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids)
x += positions
x = self.layernorm_embedding(x)
......@@ -489,17 +473,19 @@ class BartDecoder(nn.Module):
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer.forward(
x,
encoder_hidden_states,
encoder_padding_mask,
decoder_cached_states=layer_state,
layer_state=layer_state,
attention_mask=combined_mask,
need_attn_weights=self.output_attentions,
)
if self.output_past:
next_decoder_cache.append(layer_past)
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
if self.output_attentions:
......@@ -509,7 +495,22 @@ class BartDecoder(nn.Module):
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
return x, next_decoder_cache, all_hidden_states, list(all_self_attns)
if self.output_past:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
return x, next_cache, all_hidden_states, list(all_self_attns)
def reorder_attn_buffer(input_buffer, new_order):
"""Reorder buffered internal state (for incremental generation)."""
# input_buffer = self._get_input_buffer(incremental_state)
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
# incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return input_buffer
class SelfAttention(nn.Module):
......@@ -557,7 +558,7 @@ class SelfAttention(nn.Module):
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
decoder_cached_states: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = False,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
......@@ -579,8 +580,8 @@ class SelfAttention(nn.Module):
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if decoder_cached_states is not None: # get the last k,v and mask for reuse
saved_state = decoder_cached_states.get(self.cache_key, {})
if layer_state is not None: # get the last k,v and mask for reuse
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state:
# previous time steps are cached - no need to recompute key and value if they are static
if static_kv:
......@@ -588,6 +589,7 @@ class SelfAttention(nn.Module):
key = value = None
else:
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
if self.encoder_decoder_attention:
......@@ -608,17 +610,16 @@ class SelfAttention(nn.Module):
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask, new_state = self._use_and_update_saved_state(
k, v, saved_state, key_padding_mask, static_kv, bsz
)
saved_state.update(
{
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
# assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
# Update cache
layer_state[self.cache_key] = {
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask,
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}
)
decoder_cached_states[self.cache_key] = saved_state # Update cache
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
......@@ -632,7 +633,7 @@ class SelfAttention(nn.Module):
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len)
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
......@@ -650,7 +651,7 @@ class SelfAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
return attn_output, attn_weights
def _use_and_update_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
......@@ -675,7 +676,7 @@ class SelfAttention(nn.Module):
key_padding_mask = self._cat_prev_key_padding_mask(
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
)
return k, v, key_padding_mask, saved_state
return k, v, key_padding_mask
@staticmethod
def _cat_prev_key_padding_mask(
......@@ -693,7 +694,6 @@ class SelfAttention(nn.Module):
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
filler = filler.cuda()
......@@ -747,8 +747,12 @@ class LearnedPositionalEmbedding(nn.Embedding):
num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input):
def forward(self, input, generation_mode=False):
"""Input is expected to be of size [bsz x seqlen]."""
if generation_mode: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
else:
positions = create_position_ids_from_input_ids(input, self.padding_idx)
return super().forward(positions)
......@@ -826,21 +830,20 @@ class BartModel(PretrainedBartModel):
assert attention_mask.max() <= 0
# make masks if user doesn't supply
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
if not self.decoder.generation_mode:
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
)
assert decoder_input_ids is not None
if encoder_outputs is None:
# TODO(SS): make this caching more usable when overwrite generate
encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
assert isinstance(encoder_outputs, tuple)
# dec_features, decoder_cached_states, dec_hidden, dec_attn
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
decoder_outputs = self.decoder.forward(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
decoder_attn_mask,
decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
)
# Attention and hidden_states will be [] or None if they aren't needed
......@@ -856,20 +859,26 @@ class BartModel(PretrainedBartModel):
self.shared = value
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared)
return _make_linear_from_emb(self.shared) # make it on the fly
@add_start_docstrings(
"The bare BART Model with a language modeling head", BART_START_DOCSTRING,
"The bare BART Model with a language modeling head. This is the model used for summarization.",
BART_START_DOCSTRING,
)
class BartForMaskedLM(PretrainedBartModel):
base_model_prefix = "model"
def __init__(self, config: BartConfig):
super().__init__(config)
self.model = BartModel(config)
# if base_model is None:
base_model = BartModel(config)
self.model = base_model
self.lm_head = _make_linear_from_emb(self.model.shared)
def tie_weights(self):
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
......@@ -935,12 +944,309 @@ class BartForMaskedLM(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation(input_ids, past, **kwargs):
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
encoder_outputs, decoder_cached_states = past
return {
"input_ids": input_ids, # ignored after first pass
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
# "decoder_attention_mask": decoder_attention_mask,
}
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_cached_states) = past
reordered_past = []
for layer_past in decoder_cached_states:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
def get_output_embeddings(self):
return self.lm_head
@torch.no_grad()
def generate(
self,
input_ids,
attention_mask=None,
max_length=20,
num_beams=1,
repetition_penalty=1.0,
length_penalty=1.0,
num_return_sequences=1,
min_len=0,
no_repeat_ngram_size=0,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
.. _`XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
.. _`Fairseq beam search code`:
https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py
Parameters:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
min_len: (`optional`) int
Returns:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
sequence_length is <= max_length (examples can finish early)
Examples::
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
"""
bos_token_id = self.config.bos_token_id
pad_token_id = self.config.pad_token_id
eos_token_id = self.config.eos_token_id
batch_size, cur_len = input_ids.shape
assert input_ids is not None
assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert isinstance(pad_token_id, int)
assert bos_token_id == 0, "configurable bos_token_id not yet supported"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a positive integer."
# current position and vocab size
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # shape: (batch_size * num_return_sequences, cur_len)
batch_size *= num_return_sequences
# Below here somewhat similar to PretrainedModel._generate_beam_search
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
if attention_mask is not None:
attention_mask = (
attention_mask.unsqueeze(1)
.expand(batch_size, num_beams, cur_len)
.contiguous()
.view(batch_size * num_beams, cur_len)
) # RESHAPE
# generated hypotheses
finalized_hyps = [ # they end in EOS and we wont work on them more!
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9 # avoid ties in first step
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# decoder tokens
prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
prev_output_tokens[:, 0] = 2 # HARDCODED EOS, which will be removed at the end.
decoder_cache = None
done = [False for _ in range(batch_size)] # done sentences
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)
lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
lprobs[lprobs != lprobs] = -math.inf # block nans
lprobs[:, pad_token_id] = -math.inf
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
if step == 0: # Force BOS to be chosen
lprobs[:, bos_token_id + 1 :] = -math.inf
elif step < min_len: # Prevent EOS from being chosen
lprobs[:, eos_token_id] = -math.inf
elif step == max_length: # FORCE EOS to be chosen
lprobs[:, :eos_token_id] = -math.inf
lprobs[:, eos_token_id + 1 :] = -math.inf
assert self._do_output_past(outputs)
decoder_cache = outputs[1]
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
num_hypos = batch_size * num_beams
if no_repeat_ngram_size > 0: # copied from fairseq
# for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
# then set their probabilities tof -inf
for idx in range(num_hypos):
lprobs[idx, banned_tokens[idx]] = -math.inf
assert lprobs.size() == (batch_size * num_beams, vocab_size)
_scores = lprobs + beam_scores[:, None].expand_as(lprobs) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
# Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these which don't predict eos to continue with.
next_scores, next_words = torch.topk(_scores, 2 * num_beams)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
# list of (batch_size * num_beams)
next_batch_beam = [] # Tuple(next score, next word, current position in the batch)
for batch_idx in range(batch_size):
# if we are done with this sentence (because we can't improve)
if done[batch_idx]: # then pad all associated hypotheses
assert (
len(finalized_hyps[batch_idx]) >= num_beams
), "Example can only be done if at least {} beams have been generated".format(num_beams)
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# Otherwise generate some next word choices
next_sent_beam = []
# add next words for this sentence
for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
beam_id = idx // vocab_size
word_id = idx % vocab_size
assert prev_output_tokens.shape[1] == (step + 1)
if word_id.item() == eos_token_id:
if i >= num_beams:
continue
finalized_hyps[batch_idx].add(
prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
)
else:
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
if len(next_sent_beam) == num_beams: # TODO(SS): can we delete this?
break
# Check if were done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=step + 1,
)
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order decoder inputs to [beam_idx]
prev_output_tokens = prev_output_tokens[beam_idx]
prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)
# re-order internal states
decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps
if done[batch_idx]:
continue
offset = batch_idx * num_beams
for i in range(num_beams):
score = beam_scores[offset + i]
final_tokens = prev_output_tokens[offset + i]
finalized_hyps[batch_idx].add(final_tokens, score.item())
# select the best hypotheses
sent_lengths = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(finalized_hyps):
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
sent_lengths[i] = len(best_hyp)
best.append(best_hyp)
# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
# TODO(SS): decoded = torch.rnn.utils.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) # TODO(SS): same as step?
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded[:, 1:] # get rid of starting EOS
@staticmethod
def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
"""Copied from fairseq for no_repeat_ngram in beam_search"""
# TODO(SS): this can go on parent if there is demand
if step + 2 < no_repeat_ngram_size:
return [
[] for _ in range(num_hypos)
] # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
gen_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_output_tokens[idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
k = tuple(ngram[:-1])
gen_ngrams[idx][k] = gen_ngrams[idx].get(k, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
"""Before decoding the next token, prevent decoding of ngrams that have already appeared"""
ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
return gen_ngrams[hypo_idx].get(ngram_index, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
......
......@@ -171,7 +171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
......@@ -558,7 +558,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure word embedding weights are still tied if needed
# Set model in evaluation mode to desactivate DropOut modules by default
......@@ -574,16 +573,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return {"input_ids": input_ids}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
mem_len = getattr(self.config, "mem_len", 0)
if len(outputs) <= 1:
return False
if mem_len > 0 or has_output_past:
return True
return False
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
@torch.no_grad()
def generate(
self,
......@@ -761,7 +769,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # (batch_size * num_return_sequences, cur_len)
) # shape: (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences
else:
effective_batch_size = batch_size
......@@ -822,9 +830,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
......@@ -834,13 +842,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_token] < 0:
next_token_logits[i, previous_token] *= repetition_penalty
else:
next_token_logits[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
......@@ -911,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
# assert input_ids.shape == (batch_size * num_beams, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
......@@ -941,13 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size * num_beams):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= repetition_penalty
else:
scores[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
......@@ -1039,16 +1036,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
......@@ -1096,6 +1084,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
......@@ -1164,17 +1166,22 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
class Conv1D(nn.Module):
......
......@@ -19,11 +19,7 @@ from .tokenization_roberta import RobertaTokenizer
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = [
"bart-large",
"bart-large-mnli",
# "bart-large-cnn"
]
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
class BartTokenizer(RobertaTokenizer):
......
......@@ -240,7 +240,7 @@ class BartHeadTests(unittest.TestCase):
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
def test_generate(self):
def test_generate_beam_search(self):
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
config = BartConfig(
vocab_size=self.vocab_size,
......@@ -256,8 +256,12 @@ class BartHeadTests(unittest.TestCase):
)
lm_model = BartForMaskedLM(config)
lm_model.eval()
new_input_ids = lm_model.generate(input_ids)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
new_input_ids = lm_model.generate(
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=5
)
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 5))
# TODO(SS): uneven length batches, empty inputs
def test_shift_tokens_right(self):
input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
......@@ -352,3 +356,55 @@ class BartModelIntegrationTest(unittest.TestCase):
for model_name in list(BART_PRETRAINED_MODEL_ARCHIVE_MAP.keys()):
model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
@slow
def test_cnn_summarization_same_as_fairseq(self):
hf = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
extra_len = 20
gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10.,
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated = [tok.decode(g,) for g in gen_tokens]
self.assertEqual(expected_result, generated[0])
# Harder cases with batching
FRANCE_ARTICLE = ' Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
EXPECTED_SUMMARY_FRANCE = 'French prosecutor says he\'s not aware of any video footage from on board the plane. German daily Bild and French Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports "completely wrong" and "unwarranted" German airline Lufthansa confirms co-pilot Andreas Lubitz had battled depression.'
SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
EXPECTED_SUMMARY_SHORTER = "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice."
# The below article tests that we don't add any hypotheses outside of the top n_beams
IRAN_ARTICLE = " (CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
EXPECTED_SUMMARY_IRAN = "The U.S. and its negotiating partners reached a very strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. He says the agreement limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Bergen says the most important aim of a nuclear deal is preventing a nuclear Iran."
ARTICLE_SUBWAY = ' New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
dct = tok.batch_encode_plus(
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
max_length=1024,
pad_to_max_length=True,
return_tensors="pt",
)
self.assertEqual(1024, dct["input_ids"].shape[1])
hypotheses_batch = hf.generate(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4,
length_penalty=2.0,
max_length=140,
min_len=55,
no_repeat_ngram_size=3,
)
decoded = [
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
]
self.assertListEqual(
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded,
)
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment