"sgl-kernel/vscode:/vscode.git/clone" did not exist on "14e754a868619b5099688d303667d09d2ef3724c"
Unverified Commit 146c5212 authored by Lysandre Debut, committed by GitHub

Merge branch 'master' into add_models_special_tokens_to_specific_configs

parents f5b50c6b b623ddc0
......@@ -42,6 +42,7 @@ from .modeling_albert import (
AlbertForMaskedLM,
AlbertForQuestionAnswering,
AlbertForSequenceClassification,
AlbertForTokenClassification,
AlbertModel,
)
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
......@@ -233,6 +234,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(RobertaConfig, RobertaForTokenClassification),
(BertConfig, BertForTokenClassification),
(XLNetConfig, XLNetForTokenClassification),
(AlbertConfig, AlbertForTokenClassification),
]
)
......
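A sketch of how a config-to-class mapping like the one extended above is typically consumed (not part of the diff; the standalone mapping and resolver function below are illustrative assumptions, while the config and model classes come from the imports shown in the hunk):
from collections import OrderedDict
from transformers import AlbertConfig, AlbertForTokenClassification, BertConfig, BertForTokenClassification
ILLUSTRATIVE_MAPPING = OrderedDict(
    [(BertConfig, BertForTokenClassification), (AlbertConfig, AlbertForTokenClassification)]
)
def resolve_token_classification_class(config):
    # first isinstance match wins, mirroring the ordered lookup the Auto* factories perform
    for config_class, model_class in ILLUSTRATIVE_MAPPING.items():
        if isinstance(config, config_class):
            return model_class
    raise ValueError("Unrecognized configuration class {}".format(config.__class__))
assert resolve_token_classification_class(AlbertConfig()) is AlbertForTokenClassification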
......@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
from typing import Dict, List, Optional, Tuple
......@@ -24,7 +24,7 @@ from torch import Tensor, nn
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
logger = logging.getLogger(__name__)
......@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
}
BART_START_DOCSTRING = r"""
......@@ -86,7 +87,7 @@ def _prepare_bart_decoder_inputs(
causal_lm_mask = None
new_shape = (bsz, tgt_len, tgt_len)
# make it broadcastable so can just be added to the attention coefficients
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape)
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
return decoder_input_ids, decoder_attn_mask
......@@ -207,7 +208,7 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x, attn_weights = self.self_attn.forward(
x, attn_weights = self.self_attn(
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
......@@ -291,7 +292,7 @@ class BartEncoder(nn.Module):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
x, attn = encoder_layer.forward(x, attention_mask)
x, attn = encoder_layer(x, attention_mask)
if self.output_attentions:
all_attentions.append(attn)
......@@ -332,7 +333,7 @@ class DecoderLayer(nn.Module):
x,
encoder_hidden_states,
encoder_attn_mask=None,
decoder_cached_states=None,
layer_state=None,
attention_mask=None,
need_attn_weights=False,
):
......@@ -348,43 +349,28 @@ class DecoderLayer(nn.Module):
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if decoder_cached_states is None:
prev_self_attn_state, prev_attn_state = (None, None)
else:
assert len(decoder_cached_states) == 3
prev_self_attn_state, prev_attn_state = (
decoder_cached_states["self"],
decoder_cached_states["encoder_decoder"],
)
residual = x
if prev_self_attn_state is not None:
saved_state = prev_self_attn_state
decoder_cached_states["self"] = saved_state
y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it
x, self_attn_weights = self.self_attn.forward(
query=x,
key=y,
value=y,
decoder_cached_states=decoder_cached_states,
need_weights=need_attn_weights,
attn_mask=attention_mask,
if layer_state is None:
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn(
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if prev_attn_state is not None:
saved_state = prev_attn_state
decoder_cached_states["encoder_decoder"] = saved_state
x, encoder_attn_weights = self.encoder_attn.forward(
x, encoder_attn_weights = self.encoder_attn(
query=x,
key=encoder_hidden_states, # could be None
value=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
decoder_cached_states=decoder_cached_states,
layer_state=layer_state, # mutates layer state
static_kv=True,
need_weights=False, # not returning it so why compute it
)
......@@ -403,15 +389,8 @@ class DecoderLayer(nn.Module):
return (
x,
self_attn_weights,
decoder_cached_states,
) # just self_attn weights for now, following t5, decoder_cached_states = cache for decoding
def _past_to_dict(self, prev_attn_state):
prev_key, prev_value = prev_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
return saved_state
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
class BartDecoder(nn.Module):
......@@ -440,6 +419,7 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.generation_mode = False
def forward(
self,
......@@ -469,11 +449,15 @@ class BartDecoder(nn.Module):
- attentions
"""
# embed positions
positions = self.embed_positions(input_ids)
x = self.embed_tokens(input_ids)
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
if self.generation_mode:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
if positions is not None:
x += positions
x = self.embed_tokens(input_ids)
x += positions
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
......@@ -489,17 +473,19 @@ class BartDecoder(nn.Module):
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer.forward(
x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
encoder_padding_mask,
decoder_cached_states=layer_state,
layer_state=layer_state,
attention_mask=combined_mask,
need_attn_weights=self.output_attentions,
)
if self.output_past:
next_decoder_cache.append(layer_past)
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
if self.output_attentions:
......@@ -509,7 +495,22 @@ class BartDecoder(nn.Module):
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
return x, next_decoder_cache, all_hidden_states, list(all_self_attns)
if self.output_past:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
return x, next_cache, all_hidden_states, list(all_self_attns)
def reorder_attn_buffer(input_buffer, new_order):
"""Reorder buffered internal state (for incremental generation)."""
# input_buffer = self._get_input_buffer(incremental_state)
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
# incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return input_buffer
class SelfAttention(nn.Module):
......@@ -557,7 +558,7 @@ class SelfAttention(nn.Module):
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
decoder_cached_states: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = False,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
......@@ -579,8 +580,8 @@ class SelfAttention(nn.Module):
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if decoder_cached_states is not None: # get the last k,v and mask for reuse
saved_state = decoder_cached_states.get(self.cache_key, {})
if layer_state is not None: # get the last k,v and mask for reuse
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state:
# previous time steps are cached - no need to recompute key and value if they are static
if static_kv:
......@@ -588,6 +589,7 @@ class SelfAttention(nn.Module):
key = value = None
else:
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
if self.encoder_decoder_attention:
......@@ -608,17 +610,16 @@ class SelfAttention(nn.Module):
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask, new_state = self._use_and_update_saved_state(
k, v, saved_state, key_padding_mask, static_kv, bsz
)
saved_state.update(
{
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask,
}
)
decoder_cached_states[self.cache_key] = saved_state # Update cache
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
# assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
# Update cache
layer_state[self.cache_key] = {
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
......@@ -632,16 +633,16 @@ class SelfAttention(nn.Module):
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len)
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
......@@ -650,7 +651,7 @@ class SelfAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
return attn_output, attn_weights
def _use_and_update_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
......@@ -675,7 +676,7 @@ class SelfAttention(nn.Module):
key_padding_mask = self._cat_prev_key_padding_mask(
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
)
return k, v, key_padding_mask, saved_state
return k, v, key_padding_mask
@staticmethod
def _cat_prev_key_padding_mask(
......@@ -693,10 +694,9 @@ class SelfAttention(nn.Module):
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
filler = filler.cuda()
filler = filler.to(prev_key_padding_mask.device)
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
elif key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
......@@ -747,9 +747,13 @@ class LearnedPositionalEmbedding(nn.Embedding):
num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input):
def forward(self, input, generation_mode=False):
"""Input is expected to be of size [bsz x seqlen]."""
positions = create_position_ids_from_input_ids(input, self.padding_idx)
if generation_mode: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
else:
positions = create_position_ids_from_input_ids(input, self.padding_idx)
return super().forward(positions)
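A minimal sketch (not part of the diff) of the two position-id paths above, assuming BART's padding_idx = 1 and that create_position_ids_from_input_ids numbers non-pad tokens starting at padding_idx + 1:
import torch
padding_idx = 1
input_ids = torch.tensor([[0, 31414, 232, 2]])  # (bsz=1, seqlen=4), no padding
# Full-sequence path: positions follow the non-pad tokens, offset by padding_idx.
mask = input_ids.ne(padding_idx).int()
positions_full = torch.cumsum(mask, dim=1) * mask + padding_idx  # tensor([[2, 3, 4, 5]])
# Generation mode: only the current step's position is needed; it equals
# padding_idx + current sequence length, matching the last entry above.
pos = int(padding_idx + input_ids.size(1))  # 5
positions_step = input_ids.new(1, 1).fill_(pos)  # tensor([[5]])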
......@@ -826,21 +830,20 @@ class BartModel(PretrainedBartModel):
assert attention_mask.max() <= 0
# make masks if user doesn't supply
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
)
if not self.decoder.generation_mode:
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
)
assert decoder_input_ids is not None
if encoder_outputs is None:
# TODO(SS): make this caching more usable when overwrite generate
encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
assert isinstance(encoder_outputs, tuple)
# dec_features, decoder_cached_states, dec_hidden, dec_attn
decoder_outputs = self.decoder.forward(
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
decoder_attn_mask,
decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
)
# Attention and hidden_states will be [] or None if they aren't needed
......@@ -856,20 +859,26 @@ class BartModel(PretrainedBartModel):
self.shared = value
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared)
return _make_linear_from_emb(self.shared) # make it on the fly
@add_start_docstrings(
"The bare BART Model with a language modeling head", BART_START_DOCSTRING,
"The bare BART Model with a language modeling head. This is the model used for summarization.",
BART_START_DOCSTRING,
)
class BartForMaskedLM(PretrainedBartModel):
base_model_prefix = "model"
def __init__(self, config: BartConfig):
super().__init__(config)
self.model = BartModel(config)
# if base_model is None:
base_model = BartModel(config)
self.model = base_model
self.lm_head = _make_linear_from_emb(self.model.shared)
def tie_weights(self):
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
......@@ -916,7 +925,7 @@ class BartForMaskedLM(PretrainedBartModel):
outputs = model(input_ids=input_ids, lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
outputs = self.model.forward(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
......@@ -924,7 +933,7 @@ class BartForMaskedLM(PretrainedBartModel):
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
)
lm_logits = self.lm_head.forward(outputs[0])
lm_logits = self.lm_head(outputs[0])
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
if lm_labels is not None:
loss_fct = nn.CrossEntropyLoss()
......@@ -935,12 +944,309 @@ class BartForMaskedLM(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation(input_ids, past, **kwargs):
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
encoder_outputs, decoder_cached_states = past
return {
"input_ids": input_ids, # ignored after first pass
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
# "decoder_attention_mask": decoder_attention_mask,
}
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_cached_states) = past
reordered_past = []
for layer_past in decoder_cached_states:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
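For orientation (not part of the diff): `past` here is the pair ((encoder_hidden_states, encoder_padding_mask), decoder_cached_states) built in BartDecoder.forward above, and each per-layer entry is a dict of prev_key / prev_value / prev_key_padding_mask buffers whose leading dimension is batch_size * num_beams. A small sketch of the index_select reordering that reorder_attn_buffer applies to one such dict, with hypothetical shapes:
import torch
bsz_times_beams, num_heads, seq_len, head_dim = 4, 16, 7, 64
attn_cache = {
    "prev_key": torch.randn(bsz_times_beams, num_heads, seq_len, head_dim),
    "prev_value": torch.randn(bsz_times_beams, num_heads, seq_len, head_dim),
    "prev_key_padding_mask": None,
}
beam_idx = torch.tensor([0, 0, 2, 3])  # beams 0 and 1 both continue from old beam 0
reordered = {
    key: buf.index_select(0, beam_idx) if buf is not None else None
    for key, buf in attn_cache.items()
}
assert torch.equal(reordered["prev_key"][1], attn_cache["prev_key"][0])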
def get_output_embeddings(self):
return self.lm_head
@torch.no_grad()
def generate(
self,
input_ids,
attention_mask=None,
max_length=20,
num_beams=1,
repetition_penalty=1.0,
length_penalty=1.0,
num_return_sequences=1,
min_len=0,
no_repeat_ngram_size=0,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
.. _`XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
.. _`Fairseq beam search code`:
https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py
Parameters:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
min_len: (`optional`) int
The minimum length of the sequence to be generated; EOS is prevented from being chosen before this length. Default to 0.
Returns:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
sequence_length is <= max_length (examples can finish early)
Examples::
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
"""
bos_token_id = self.config.bos_token_id
pad_token_id = self.config.pad_token_id
eos_token_id = self.config.eos_token_id
batch_size, cur_len = input_ids.shape
assert input_ids is not None
assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert isinstance(pad_token_id, int)
assert bos_token_id == 0, "configurable bos_token_id not yet supported"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a positive integer."
# current position and vocab size
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # shape: (batch_size * num_return_sequences, cur_len)
batch_size *= num_return_sequences
# Below here somewhat similar to PretrainedModel._generate_beam_search
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
if attention_mask is not None:
attention_mask = (
attention_mask.unsqueeze(1)
.expand(batch_size, num_beams, cur_len)
.contiguous()
.view(batch_size * num_beams, cur_len)
) # RESHAPE
# generated hypotheses
finalized_hyps = [ # they end in EOS and we won't work on them more!
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9 # avoid ties in first step
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# decoder tokens
prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
prev_output_tokens[:, 0] = 2 # HARDCODED EOS, which will be removed at the end.
decoder_cache = None
done = [False for _ in range(batch_size)] # done sentences
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)
lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
lprobs[lprobs != lprobs] = -math.inf # block nans
lprobs[:, pad_token_id] = -math.inf
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
if step == 0: # Force BOS to be chosen
lprobs[:, bos_token_id + 1 :] = -math.inf
elif step < min_len: # Prevent EOS from being chosen
lprobs[:, eos_token_id] = -math.inf
elif step == max_length: # FORCE EOS to be chosen
lprobs[:, :eos_token_id] = -math.inf
lprobs[:, eos_token_id + 1 :] = -math.inf
assert self._do_output_past(outputs)
decoder_cache = outputs[1]
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
num_hypos = batch_size * num_beams
if no_repeat_ngram_size > 0: # copied from fairseq
# for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
# then set their probabilities to -inf
for idx in range(num_hypos):
lprobs[idx, banned_tokens[idx]] = -math.inf
assert lprobs.size() == (batch_size * num_beams, vocab_size)
_scores = lprobs + beam_scores[:, None].expand_as(lprobs) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
# Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these which don't predict eos to continue with.
next_scores, next_words = torch.topk(_scores, 2 * num_beams)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
# list of (batch_size * num_beams)
next_batch_beam = [] # Tuple(next score, next word, current position in the batch)
for batch_idx in range(batch_size):
# if we are done with this sentence (because we can't improve)
if done[batch_idx]: # then pad all associated hypotheses
assert (
len(finalized_hyps[batch_idx]) >= num_beams
), "Example can only be done if at least {} beams have been generated".format(num_beams)
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# Otherwise generate some next word choices
next_sent_beam = []
# add next words for this sentence
for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
beam_id = idx // vocab_size
word_id = idx % vocab_size
assert prev_output_tokens.shape[1] == (step + 1)
if word_id.item() == eos_token_id:
if i >= num_beams:
continue
finalized_hyps[batch_idx].add(
prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
)
else:
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
if len(next_sent_beam) == num_beams: # TODO(SS): can we delete this?
break
# Check if we're done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=step + 1,
)
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order decoder inputs to [beam_idx]
prev_output_tokens = prev_output_tokens[beam_idx]
prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)
# re-order internal states
decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
for batch_idx in range(batch_size):
# Add all open beam hypotheses to finalized_hyps
if done[batch_idx]:
continue
offset = batch_idx * num_beams
for i in range(num_beams):
score = beam_scores[offset + i]
final_tokens = prev_output_tokens[offset + i]
finalized_hyps[batch_idx].add(final_tokens, score.item())
# select the best hypotheses
sent_lengths = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(finalized_hyps):
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
sent_lengths[i] = len(best_hyp)
best.append(best_hyp)
# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
# TODO(SS): decoded = torch.nn.utils.rnn.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) # TODO(SS): same as step?
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded[:, 1:] # get rid of starting EOS
@staticmethod
def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
"""Copied from fairseq for no_repeat_ngram in beam_search"""
# TODO(SS): this can go on parent if there is demand
if step + 2 < no_repeat_ngram_size:
return [
[] for _ in range(num_hypos)
] # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
gen_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_output_tokens[idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
k = tuple(ngram[:-1])
gen_ngrams[idx][k] = gen_ngrams[idx].get(k, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
"""Before decoding the next token, prevent decoding of ngrams that have already appeared"""
ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
return gen_ngrams[hypo_idx].get(ngram_index, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
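A worked example (not part of the diff) of the bookkeeping above: with no_repeat_ngram_size=2 and one hypothesis that has already produced [2, 7, 9, 7] (so step == 3), the generated bigrams are (2, 7), (7, 9) and (9, 7); the current 1-token prefix is (7,), so token 9 is banned for that hypothesis. The top-level import below assumes this version of the library exports BartForMaskedLM:
import torch
from transformers import BartForMaskedLM  # assumed export; calc_banned_tokens is a staticmethod
prev_output_tokens = torch.tensor([[2, 7, 9, 7]])
banned = BartForMaskedLM.calc_banned_tokens(prev_output_tokens, num_hypos=1, no_repeat_ngram_size=2, step=3)
assert banned == [[9]]  # token 9 may not follow token 7 again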
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
......@@ -1002,7 +1308,7 @@ class BartForSequenceClassification(PretrainedBartModel):
loss, logits = outputs[:2]
"""
outputs = self.model.forward(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
......@@ -1018,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
# Prepend logits
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
if labels is not None: # prepend loss to output,
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs
......@@ -1230,7 +1230,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor`` of shape ``(1,)`, `optional`, returned when :obj:`labels` is provided):
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
......@@ -1382,8 +1382,10 @@ class BertForTokenClassification(BertPreTrainedModel):
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......
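The RoBERTa and DistilBERT hunks below apply the same change. A minimal, self-contained sketch (not part of the diff, with made-up shapes) of what the new masking computes: padded positions keep their logits but have their labels replaced by the loss's ignore_index, so CrossEntropyLoss skips them instead of relying on boolean indexing:
import torch
from torch import nn
num_labels = 3
logits = torch.randn(2, 4, num_labels)  # (batch, seq_len, num_labels)
labels = torch.tensor([[1, 2, 0, 0], [2, 1, 1, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
active_labels = torch.where(
    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)  # padded positions contribute nothing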
......@@ -454,14 +454,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# only last token for input_ids if past is defined
if "past" in kwargs and kwargs["past"]:
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids}
inputs.update(kwargs)
return inputs
return {"input_ids": input_ids, "past": past}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def forward(
......
......@@ -818,8 +818,10 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......
......@@ -234,62 +234,3 @@ class PreTrainedEncoderDecoder(nn.Module):
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedEncoderDecoder):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
where both the encoder and the decoder are of the same family. If the
name of or the path to a pretrained model is specified, the encoder and
the decoder will be initialized with the pretrained weights (the
cross-attention weights will be initialized randomly if they are not
present).
It is possible to override this behavior and initialize, say, the decoder randomly
by creating it beforehand as follows
config = BertConfig.from_pretrained('bert-base-uncased')
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tie_weights()
def tie_weights(self):
""" Tying the encoder and decoders' embeddings together.
For each of them we need to get down to the embedding weights. However, the
different model classes are inconsistent in that respect:
- BertModel: embeddings.word_embeddings
- RoBERTa: embeddings.word_embeddings
- XLMModel: embeddings
- GPT2: wte
- BertForMaskedLM: bert.embeddings.word_embeddings
- RobertaForMaskedLM: roberta.embeddings.word_embeddings
argument of the XEmbedding layer for each model, but it is "blocked"
by a model-specific keyword (bert, )...
"""
# self._tie_or_clone_weights(self.encoder, self.decoder)
pass
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if (
"bert" not in pretrained_model_name_or_path
or "roberta" in pretrained_model_name_or_path
or "distilbert" in pretrained_model_name_or_path
):
raise ValueError("Only the Bert model is currently supported.")
model = super().from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
**kwargs,
)
return model
......@@ -148,9 +148,12 @@ class FlaubertModel(XLMModel):
Examples::
from transformers import FlaubertTokenizer, FlaubertModel
import torch
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-cased')
model = FlaubertModel.from_pretrained('flaubert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Le chat manges une pomme.", add_special_tokens=True)).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Le chat mange une pomme.", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
......
......@@ -276,14 +276,17 @@ GPT2_START_DOCSTRING = r"""
GPT2_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past` is None else 1
Indices of input sequence tokens in the vocabulary.
If using `past` as an input make sure that `input_ids` are those of the last position.
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
......@@ -294,10 +297,12 @@ GPT2_INPUTS_DOCSTRING = r"""
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
`input_ids_length` = `sequence_length` if `past` is None else 1
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
If using `past` as an input make sure that `token_type_ids` correspond to the `input_ids` of the last position.
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -419,7 +424,8 @@ class GPT2Model(GPT2PreTrainedModel):
# Attention mask.
if attention_mask is not None:
attention_mask = attention_mask.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
......@@ -519,14 +525,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# only last token for input_ids if past is defined
if "past" in kwargs and kwargs["past"]:
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids}
inputs.update(kwargs)
return inputs
return {"input_ids": input_ids, "past": past}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
def forward(
......
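A usage sketch (not part of the diff) of the incremental-decoding contract that the new prepare_inputs_for_generation(input_ids, past, **kwargs) signature supports; the 'gpt2' checkpoint and the manual two-step decode are illustrative assumptions rather than anything taken from the commit:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
input_ids = torch.tensor(tokenizer.encode("The dog")).unsqueeze(0)
with torch.no_grad():
    logits, past = model(input_ids)[:2]  # first pass: full prompt, key/value cache returned
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=-1)
model_inputs = model.prepare_inputs_for_generation(input_ids, past=past)
# -> {"input_ids": <last token only>, "past": <layer-wise key/value cache>}
with torch.no_grad():
    logits, past = model(**model_inputs)[:2]  # later passes feed only the newest token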
......@@ -542,13 +542,16 @@ class RobertaForTokenClassification(BertPreTrainedModel):
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......
......@@ -480,7 +480,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
class TFAlbertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
......
......@@ -668,38 +668,39 @@ class TFBertModel(TFBertPreTrainedModel):
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
Returns:
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import BertTokenizer, TFBertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertModel.from_pretrained('bert-base-uncased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import BertTokenizer, TFBertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertModel.from_pretrained('bert-base-uncased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
outputs = self.bert(inputs, **kwargs)
return outputs
......
......@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=1)
k = tf.concat((past_key, k), dim=-2)
v = tf.concat((past_value, v), dim=-2)
present = tf.stack((k, v), axis=1)
past_key, past_value = tf.unstack(layer_past, axis=0)
k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2)
present = tf.stack((k, v), axis=0)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
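A minimal sketch (not part of the diff, with hypothetical shapes) of the cache-layout change above: key and value are now stacked along a new leading axis (axis 0 rather than axis 1), and the concatenation uses TensorFlow's axis keyword instead of the Torch-style dim:
import tensorflow as tf
k = tf.zeros((2, 12, 5, 64))  # (batch, heads, seq_len, head_dim)
v = tf.zeros((2, 12, 5, 64))
present = tf.stack((k, v), axis=0)  # shape (2, 2, 12, 5, 64)
past_key, past_value = tf.unstack(present, axis=0)
new_k = tf.zeros((2, 12, 1, 64))  # key for the next single step
k_with_past = tf.concat((past_key, new_k), axis=-2)
assert k_with_past.shape == (2, 12, 6, 64)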
......@@ -505,6 +505,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.input_embeddings
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# only last token for inputs if past is defined
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
......
......@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
key = self.split_heads(key)
value = self.split_heads(value)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=1)
past_key, past_value = tf.unstack(layer_past, axis=0)
key = tf.concat([past_key, key], axis=-2)
value = tf.concat([past_value, value], axis=-2)
present = tf.stack([key, value], axis=1)
present = tf.stack([key, value], axis=0)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
a = attn_outputs[0]
......@@ -500,6 +500,13 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
def get_output_embeddings(self):
return self.transformer.wte
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# only last token for inputs if past is defined
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
......
......@@ -199,7 +199,7 @@ class TFBlock(tf.keras.layers.Layer):
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
super().__init__(*inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.num_hidden_layers = config.n_layer
......
......@@ -826,3 +826,12 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
outputs = [softmax_output] + outputs
return outputs # logits, new_mems, (all hidden states), (all attentions)
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
inputs = {"inputs": inputs}
# if past is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = past
return inputs
......@@ -142,7 +142,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# # initialize all new embeddings (in particular added tokens)
# self._init_weights(new_embeddings)
# # Copy word embeddings from the previous weights
# # Copy token embeddings from the previous weights
# num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
# new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
......@@ -384,6 +384,724 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
return model
def prepare_inputs_for_generation(self, inputs, **kwargs):
return {"inputs": inputs}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
return True
return False
def generate(
self,
input_ids=None,
max_length=None,
do_sample=True,
num_beams=None,
temperature=None,
top_k=None,
top_p=None,
repetition_penalty=None,
bos_token_id=None,
pad_token_id=None,
eos_token_ids=None,
length_penalty=None,
num_return_sequences=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
Adapted in part from `Facebook's XLM beam search code`_.
.. _`Facebook's XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
Parameters:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
do_sample: (`optional`) bool
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
temperature: (`optional`) float
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
top_k: (`optional`) int
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
top_p: (`optional`) float
The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
bos_token_id: (`optional`) int
Beginning of sentence token if no prompt is provided. Default to 0.
eos_token_ids: (`optional`) int or list of int
End of sequence token or list of tokens to stop the generation. Default to 0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
Return:
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
Examples::
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False) # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
for i in range(3): # 3 output sequences were generated
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3) # generate 3 sequences by sampling
for i in range(3): # 3 output sequences were generated
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
"""
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
)
max_length = max_length if max_length is not None else self.config.max_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
if input_ids is not None:
batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
else:
batch_size = 1
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
isinstance(bos_token_id, int) and bos_token_id >= 0
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer."
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictely positive integer."
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
if do_sample is False:
if num_beams == 1:
# no_beam_search greedy generation conditions
assert (
num_return_sequences == 1
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
else:
# beam_search greedy generation conditions
assert (
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
if pad_token_id is None and eos_token_ids is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
)
pad_token_id = eos_token_ids[0]
# current position and vocab size
cur_len = shape_list(input_ids)[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1 and do_sample:
# Expand input to num return sequences
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
effective_batch_size = batch_size * num_return_sequences
input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len))
else:
effective_batch_size = batch_size
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
)
else:
output = self._generate_no_beam_search(
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
effective_batch_size,
)
return output
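# Hedged usage sketch (checkpoint, prompt, and generation settings below are illustrative
# assumptions, not taken from this diff): a TF LM-head model that implements
# `prepare_inputs_for_generation` dispatches through `generate` either into sampling
# without beams or into beam search, depending on `num_beams`.
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
prompt_ids = tokenizer.encode("The library is", return_tensors="tf")
sampled = model.generate(prompt_ids, max_length=20, do_sample=True, top_k=50, top_p=0.95)
beamed = model.generate(prompt_ids, max_length=20, do_sample=False, num_beams=4)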
def _generate_no_beam_search(
self,
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
batch_size,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
# length of generated sentences / unfinished sentences
unfinished_sents = tf.ones_like(input_ids[:, 0])
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
next_token_logits_penalties = _create_next_token_logits_penalties(
input_ids, next_token_logits, repetition_penalty
)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# Top-p/top-k filtering
next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
# Sample
next_token = tf.squeeze(
tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
)
else:
# Greedy decoding
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
# update generations and finished sentences
if eos_token_ids is not None:
# pad finished sentences if eos_token_ids exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else:
tokens_to_add = next_token
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
if eos_token_ids is not None:
for eos_token_id in eos_token_ids:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
)
sent_lengths = (
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
cur_len = cur_len + 1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if tf.math.reduce_max(unfinished_sents) == 0:
break
# if there are different sentences lengths in the batch, some batches have to be padded
min_sent_length = tf.math.reduce_min(sent_lengths)
max_sent_length = tf.math.reduce_max(sent_lengths)
if min_sent_length != max_sent_length:
assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
# finished sents are filled with pad_token
padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id
# create length masks for tf.where operation
broad_casted_sent_lengths = tf.broadcast_to(
tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
)
broad_casted_range = tf.transpose(
tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
)
decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
else:
decoded = input_ids
return decoded
def _generate_beam_search(
self,
input_ids,
cur_len,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
pad_token_id,
eos_token_ids,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len)) # (batch_size * num_beams, cur_len)
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# scores for each sentence in the beam
if do_sample is False:
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * -1e9
beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
else:
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
# cache compute states
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
next_token_logits_penalties = _create_next_token_logits_penalties(
input_ids, next_token_logits, repetition_penalty
)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, vocab_size)
) # (batch_size * num_beams, vocab_size)
# Top-p/top-k filtering
_scores = tf_top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
_scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
next_tokens = tf.random.categorical(
_scores, dtype=tf.int32, num_samples=2 * num_beams
) # (batch_size, 2 * num_beams)
# Compute next scores
next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
else:
# do greedy beam search
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
next_scores = scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, vocab_size)
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
next_scores = tf.reshape(
next_scores, (batch_size, num_beams * vocab_size)
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy()
)
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_ids is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next tokens for this sentence
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and token IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
# add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
generated_hyps[batch_idx].add(
tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
)
else:
# add next predicted token if it is not eos_token
next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
break
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
# re-order batch
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
# re-order internal states
if past:
past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence
if all(done):
break
for batch_idx in range(batch_size):
# Add all open beam hypotheses to generated_hyps
if not done[batch_idx]:
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and token IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
generated_hyps[batch_idx].add(
tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
sent_lengths_list = []
best = []
# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
for j in range(output_num_return_sequences_per_batch):
best_hyp = sorted_hyps.pop()[1]
sent_lengths_list.append(len(best_hyp))
best.append(best_hyp)
assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
output_batch_size, len(best)
)
sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
# shorter batches are filled with pad_token
if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
decoded_list = []
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
decoded_hypo = tf.concat([hypo, padding], axis=0)
if sent_lengths[i] < max_length:
decoded_hypo = tf.where(
tf.range(sent_max_len) == sent_lengths[i],
eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
decoded_hypo,
)
decoded_list.append(decoded_hypo)
decoded = tf.stack(decoded_list)
else:
# none of the hypotheses have an eos_token
assert all(len(hypo) == max_length for hypo in best)
decoded = tf.stack(best)
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
# check that shape matches
assert shape_list(reordered_layer_past) == shape_list(layer_past)
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
token_penalties = np.ones(shape_list(logits))
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id]
logit_penalties = np.zeros(logit_penalized.shape)
# if the previous logit score is < 0 then the repetition penalty has to be multiplied to reduce the token probability, otherwise divided
logit_penalties[logit_penalized < 0] = repetition_penalty
logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalties)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
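# Hedged illustrative sketch (toy tensors below are assumptions, not part of the module):
# tokens that already appear in `input_ids` receive a multiplicative penalty, so a positive
# logit is divided by `repetition_penalty` and a negative one is multiplied by it, pushing
# both towards lower probability.
import tensorflow as tf
toy_input_ids = tf.constant([[3, 5, 5]])                      # tokens 3 and 5 were already generated
toy_logits = tf.constant([[0.2, -0.1, 0.4, 1.5, 0.0, -2.0]])  # (batch=1, vocab=6)
penalties = _create_next_token_logits_penalties(toy_input_ids, toy_logits, 1.2)
penalized = tf.math.multiply(toy_logits, penalties)           # token 3 -> 1.5 / 1.2, token 5 -> -2.0 * 1.2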
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
logits_shape = shape_list(logits)
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
if top_p < 1.0:
sorted_indices = tf.argsort(logits, direction="DESCENDING")
sorted_logits = tf.gather(
logits, sorted_indices, axis=-1, batch_dims=1
) # expects logits to be of dim (batch_size, vocab_size)
cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove = tf.concat(
[
tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
sorted_indices_to_remove[:, min_tokens_to_keep:],
],
-1,
)
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = tf.concat(
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
)
# scatter sorted tensors to original indexing
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
return logits
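# Hedged illustrative sketch (toy logits are assumptions): keep the top-2 tokens and the
# smallest nucleus reaching cumulative probability 0.9, then sample one token id from the
# filtered distribution.
import tensorflow as tf
toy_logits = tf.constant([[4.0, 3.0, 1.0, 0.5, -1.0]])  # (batch=1, vocab=5)
filtered_logits = tf_top_k_top_p_filtering(toy_logits, top_k=2, top_p=0.9)
next_token = tf.squeeze(tf.random.categorical(filtered_logits, dtype=tf.int32, num_samples=1), axis=1)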
def scatter_values_on_batch_indices(values, batch_indices):
shape = shape_list(batch_indices)
# broadcast batch dim to shape
broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
# transform batch_indices to pair_indices
pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
# scatter values to pair indices
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
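# Hedged illustrative sketch (toy tensors are assumptions): scatter per-row values given in
# sorted order back to their original positions, undoing an argsort along the last axis.
import tensorflow as tf
toy_logits = tf.constant([[0.1, 2.0, 1.0]])
sorted_idx = tf.argsort(toy_logits, direction="DESCENDING")              # [[1, 2, 0]]
sorted_vals = tf.gather(toy_logits, sorted_idx, axis=-1, batch_dims=1)   # [[2.0, 1.0, 0.1]]
restored = scatter_values_on_batch_indices(sorted_vals, sorted_idx)      # [[0.1, 2.0, 1.0]]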
def set_tensor_by_indices_to_value(tensor, indices, value):
# create value_tensor since tensor value assignment is not possible in TF
value_tensor = tf.zeros_like(tensor) + value
return tf.where(indices, value_tensor, tensor)
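# Hedged illustrative sketch (toy tensors are assumptions): replace the entries selected by
# a boolean mask with a constant value, here masking one logit to -inf.
import tensorflow as tf
toy_logits = tf.constant([[0.5, 1.5, -0.3]])
toy_mask = tf.constant([[False, True, False]])
masked_logits = set_tensor_by_indices_to_value(toy_logits, toy_mask, -float("Inf"))  # [[0.5, -inf, -0.3]]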
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
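# Hedged illustrative sketch (toy hypotheses and scores are assumptions): hypotheses are
# ranked by sum_logprobs / len(hyp) ** length_penalty, and `is_done` reports True once no
# open beam can still beat the worst kept hypothesis.
import tensorflow as tf
toy_hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
toy_hyps.add(tf.constant([5, 7, 9]), -3.0)     # score = -3.0 / 3 = -1.0
toy_hyps.add(tf.constant([5, 7, 2, 4]), -2.0)  # score = -2.0 / 4 = -0.5
done = toy_hyps.is_done(best_sum_logprobs=-9.5, cur_len=5)  # -9.5 / 5 = -1.9 < -1.0 -> True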
class TFConv1D(tf.keras.layers.Layer):
def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
......@@ -423,7 +1141,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
def build(self, input_shape):
"""Build shared word embedding layer
"""Build shared token embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
......
......@@ -657,6 +657,20 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.input_embeddings
def prepare_inputs_for_generation(self, inputs, **kwargs):
mask_token_id = self.config.mask_token_id
lang_id = self.config.lang_id
effective_batch_size = inputs.shape[0]
mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
inputs = tf.concat([inputs, mask_token], axis=1)
if lang_id is not None:
langs = tf.ones_like(inputs) * lang_id
else:
langs = None
return {"inputs": inputs, "langs": langs}
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
......
......@@ -837,6 +837,32 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss.input_embeddings
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
inputs = tf.concat([inputs, dummy_token], axis=1)
# Build permutation mask so that previous tokens don't see last token
sequence_length = inputs.shape[1]
perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
# We'll only predict the last token
target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {"inputs": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping}
# if past is defined in model kwargs then use it for faster decoding
if past:
inputs["mems"] = past
return inputs
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
......
......@@ -935,11 +935,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else:
return self.crit.out_layers[-1]
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
inputs = {"input_ids": input_ids}
# if past is defined in model kwargs then use it for faster decoding
if "past" in model_kwargs and model_kwargs["past"]:
inputs["mems"] = model_kwargs["past"]
if past:
inputs["mems"] = past
return inputs
......@@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch BERT model."""
import logging
import os
import typing
......@@ -171,7 +170,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
......@@ -242,7 +241,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy word embeddings from the previous weights
# Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
......@@ -540,6 +539,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
if model.__class__.__name__ != model_to_load.__class__.__name__:
base_model_state_dict = model_to_load.state_dict().keys()
head_model_state_dict_without_base_prefix = [
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
]
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(
......@@ -558,14 +566,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure word embedding weights are still tied if needed
model.tie_weights() # make sure token embedding weights are still tied if needed
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
......@@ -574,16 +585,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return {"input_ids": input_ids}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
mem_len = getattr(self.config, "mem_len", 0)
if len(outputs) <= 1:
return False
if mem_len > 0 or has_output_past:
return True
return False
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
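# Hedged illustrative sketch (`model` is assumed to be any loaded PreTrainedModel; the toy
# tensors are assumptions): the penalty is applied in place to a
# (batch_size * num_beams, vocab_size) score matrix for tokens already generated.
import torch
toy_scores = torch.tensor([[1.5, -2.0, 0.3]])  # one sequence, vocabulary of 3
toy_prev_tokens = torch.tensor([[0, 1]])       # tokens 0 and 1 were already produced
model.enforce_repetition_penalty_(toy_scores, 1, 1, toy_prev_tokens, 1.3)
# token 0: 1.5 / 1.3, token 1: -2.0 * 1.3 -> both become less likely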
@torch.no_grad()
def generate(
self,
......@@ -626,7 +646,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
temperature: (`optional`) float
The value used to module the next token probabilities. Must be strictely positive. Default to 1.0.
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
top_k: (`optional`) int
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
......@@ -714,10 +734,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
assert temperature > 0, "`temperature` should be strictely positive."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
......@@ -730,10 +750,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and all(isinstance(e, int) and e >= 0 for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert length_penalty > 0, "`length_penalty` should be strictely positive."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictely positive integer."
), "`num_return_sequences` should be a strictly positive integer."
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
......@@ -746,6 +766,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
# not allow to duplicate outputs when greedy decoding
if do_sample is False:
if num_beams == 1:
# no_beam_search greedy generation conditions
assert (
num_return_sequences == 1
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
else:
# beam_search greedy generation conditions
assert (
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
if pad_token_id is None and eos_token_ids is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
......@@ -756,15 +790,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # (batch_size * num_return_sequences, cur_len)
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
effective_batch_size = batch_size * num_return_sequences
effective_batch_mult = num_return_sequences
else:
effective_batch_size = batch_size
effective_batch_mult = 1
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if num_beams > 1:
output = self._generate_beam_search(
......@@ -779,6 +819,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
......@@ -817,14 +858,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
# current position / max lengths / length of generated sentences / unfinished sentences
# length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
......@@ -834,13 +875,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_token] < 0:
next_token_logits[i, previous_token] *= repetition_penalty
else:
next_token_logits[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
......@@ -872,12 +907,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long())
cur_len = cur_len + 1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break
cur_len = cur_len + 1
# if there are different sentences lengths in the batch, some batches have to be padded
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
......@@ -904,15 +939,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id,
eos_token_ids,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
# generated hypotheses
generated_hyps = [
......@@ -921,7 +954,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
# For greedy decoding, make sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens from every (identical) beam
if do_sample is False:
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
......@@ -933,7 +968,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
......@@ -941,42 +976,53 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size * num_beams):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= repetition_penalty
else:
scores[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
scores = scores / temperature
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# Top-p/top-k filtering
scores = top_k_top_p_filtering(
scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
_scores = top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2)
# re-organize to group the beam together to sample from all beam_idxs
_scores = _scores.contiguous().view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
next_tokens = torch.multinomial(
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
) # (batch_size, num_beams * 2)
# Compute next scores
_scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2)
next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2)
# Match shape of greedy beam search
next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
else:
# do greedy beam search
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
assert scores.size() == (batch_size * num_beams, vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
next_scores = next_scores.view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
......@@ -1002,21 +1048,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
# next tokens for this sentence
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and word IDs
beam_id = idx // vocab_size
word_id = idx % vocab_size
token_id = idx % vocab_size
# add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and word_id.item() in eos_token_ids:
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if eos_token_ids is not None and token_id.item() in eos_token_ids:
generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
input_ids[effective_beam_id].clone(), score.item(),
)
else:
# add next predicted word if it is not eos_token
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
next_sent_beam.append((score, token_id, effective_beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
......@@ -1030,59 +1077,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
# update current length
cur_len = cur_len + 1
past = self._reorder_cache(past, beam_idx)
# stop when we are done with each sentence
if all(done):
break
# update current length
cur_len = cur_len + 1
# finalize all open beam hypotheses and add them to generated hypotheses
for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps
if not done[batch_idx]:
for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
if done[batch_idx]:
continue
# get beam and word IDs
beam_id = idx // vocab_size
word_id = idx % vocab_size
generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
)
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_ids is not None and all(
(token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
):
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx]
)
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(num_beams):
effective_beam_id = batch_idx * num_beams + beam_id
final_score = beam_scores[effective_beam_id].item()
final_tokens = input_ids[effective_beam_id]
generated_hyps[batch_idx].add(final_tokens, final_score)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
sent_lengths = input_ids.new(batch_size)
sent_lengths = input_ids.new(output_batch_size)
best = []
# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
sent_lengths[i] = len(best_hyp)
best.append(best_hyp)
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
for j in range(output_num_return_sequences_per_batch):
effective_batch_idx = output_num_return_sequences_per_batch * i + j
best_hyp = sorted_hyps.pop()[1]
sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp)
# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
......@@ -1096,6 +1152,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
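# Hedged illustrative sketch (the toy cache is an assumption): each layer's cached states
# are gathered along their batch dimension (dim 1) so they keep following the beams that
# survived the last selection step.
import torch
toy_layer_past = torch.arange(24, dtype=torch.float).view(2, 3, 4)  # (2, batch=3, hidden=4)
beam_idx = torch.tensor([2, 0, 0])                                   # beams 2, 0 and 0 survive
reordered = PreTrainedModel._reorder_cache((toy_layer_past,), beam_idx)
assert reordered[0].shape == toy_layer_past.shape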
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
......@@ -1164,17 +1234,22 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
class Conv1D(nn.Module):
......