Commit 77d39720 authored by thomwolf

clean up dead code

parent bbc0c86f
@@ -830,146 +830,6 @@ class BeamHypotheses(object):
        return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty


class Sampler(object):
    r""" Sampler is used to generate sequences of ids from logit inputs.

    Greedy decoding, which consists in choosing the most probable token at
    each step, is the default behaviour. Sampling with varying temperature,
    top-k and nucleus filtering is also implemented.

    Attributes:
        **device**: ``torch.device``
            Device on which the computations will be run.
        **do_sample**: bool
            Whether to sample or do greedy decoding.
        **k**: int between 0 and vocab_size
            Parameter for the top-k filtering
        **p**: float between 0 and 1
            Parameter for the nucleus filtering
        **temperature**: strictly positive float
            Parameter used to modulate the distribution over ids. Low
            temperatures put more emphasis on highly probable tokens while
            high temperatures tend to smooth the probability distribution.
        **repetition_penalty**: strictly positive float
            The penalty applied to repeated ids
    """
    def __init__(
        self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
    ):
        self.k = k
        self.p = p
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        self.do_apply_repetition_penalty = repetition_penalty > 1
        if self.p > 1:
            warnings.warn(
                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
                However, p is a probability and its value must lie between 0 and 1. In effect, no
                filtering will be applied. If this is not the behavior you expect, change the value of p.""".format(
                    self.p
                )
            )
    def get_one_token(self, next_token_logits, past_sequence):
        logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
        if self.do_sample:
            logits = self.apply_temperature(logits)
            logits = self.apply_top_k_filter(logits)
            logits = self.apply_nucleus_filter(logits)
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        return torch.argmax(logits, dim=-1).unsqueeze(-1)
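For orientation, a minimal sketch of the decoding loop this class is meant to drive. The random logits stand in for a real model's forward pass, and the module-level imports of the full file (`torch`, `torch.nn.functional as F`, `warnings`) are assumed:

import torch

sampler = Sampler(do_sample=True, k=9, temperature=0.7)
sequence = torch.tensor([[464]])  # a batch of one prompt token
for _ in range(5):
    next_token_logits = torch.randn(1, 50257)  # stand-in for model(sequence)
    next_token = sampler.get_one_token(next_token_logits, sequence)
    sequence = torch.cat([sequence, next_token], dim=-1)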
    def apply_repetition_penalty(self, logits, past_sequence):
        """ Apply a penalty to the tokens that already appear in the generated
        sequence, to discourage repetition.
        .. Keskar, Nitish Shirish, et al. "CTRL: A conditional transformer
           language model for controllable generation." arXiv preprint
           arXiv:1909.05858 (2019).
        """
        if self.do_apply_repetition_penalty:
            generated_token_idx = set(past_sequence[0].tolist())
            for token_idx in generated_token_idx:
                # Dividing a negative logit by the penalty would *increase* its
                # probability; multiply instead so that repeated tokens are
                # always discouraged.
                if logits[0, token_idx] < 0:
                    logits[0, token_idx] *= self.repetition_penalty
                else:
                    logits[0, token_idx] /= self.repetition_penalty
        return logits
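To make the effect concrete, a toy example (values are illustrative, not from the commit): token 2 already occurs in the past sequence, so its positive logit is divided by the penalty and its probability shrinks:

import torch

sampler = Sampler(repetition_penalty=1.3)
logits = torch.tensor([[2.0, 1.0, 3.0]])
past_sequence = torch.tensor([[2, 2]])
penalized = sampler.apply_repetition_penalty(logits, past_sequence)
# logits[0, 2] goes from 3.0 to 3.0 / 1.3 ≈ 2.31; the other entries are untouched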
    def apply_temperature(self, logits):
        """ Shape the tokens' distribution through temperature. The lower the
        temperature, the more skewed towards high-probability events the
        distribution is; high temperatures flatten it towards uniform.
        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
           MIT press, 2016.
        """
        # When dividing a float by 0, torch returns inf, which in turn breaks
        # the multinomial call with an error message that is not very helpful.
        # It is better for the user to break the execution and explain why.
        if self.temperature == 0:
            raise ZeroDivisionError(
                """You are trying to sample with a temperature equal to 0.
                If you wanted to do greedy decoding, set `do_sample` to False instead.
                Otherwise set the temperature to a value different from 0."""
            )
        return logits / self.temperature
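A quick way to see the effect is to compare the same logits at two temperatures; this toy snippet is illustrative only:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.0]])
print(F.softmax(logits / 0.5, dim=-1))  # ≈ [0.867, 0.117, 0.016] (sharpened)
print(F.softmax(logits / 2.0, dim=-1))  # ≈ [0.506, 0.307, 0.186] (flattened)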
    def apply_top_k_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically we keep the k most probable tokens,
        i.e. the set of size k whose total probability is maximal.
        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
           story generation." arXiv preprint arXiv:1805.04833 (2018).
        """
        if self.k > 0:
            vocabulary_size = logits.size(-1)
            if self.k > vocabulary_size:
                warnings.warn(
                    """You provided a value for k ({}) that is larger than the vocabulary size ({}).
                    We adjusted k's value to the vocabulary size; if that was what you intended to do,
                    we recommend setting k to 0 instead. If this is not the behavior you expected,
                    choose a value of k that is smaller than the vocabulary size.""".format(
                        self.k, vocabulary_size
                    )
                )
                self.k = vocabulary_size
            indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
            logits[indices_to_remove] = -float("Inf")
        return logits
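As a toy illustration (values chosen for this sketch, not from the commit), with k=2 only the two largest logits survive; everything else is masked to -inf and thus gets zero probability under the softmax:

import torch

sampler = Sampler(do_sample=True, k=2)
logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])
filtered = sampler.apply_top_k_filter(logits)
# filtered == tensor([[-inf, 3.0, -inf, 2.0]])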
    def apply_nucleus_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically, choose the smallest set such that the
        sum of its items' probabilities is greater than a number p in [0,1].
        .. Holtzman, Ari, et al. "The curious case of neural text
           degeneration." arXiv preprint arXiv:1904.09751 (2019).
        """
        if self.p > 0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probabilities = F.softmax(sorted_logits, dim=-1)
            cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)
            # Remove tokens with cumulative probability above the threshold,
            # but keep the first token above the threshold.
            sorted_indices_to_remove = cumulative_probabilities > self.p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the mask back to the original (unsorted) indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = -float("Inf")
        return logits
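A toy illustration of the nucleus filter (values chosen for this sketch): with p=0.9, the smallest set of tokens whose cumulative probability crosses 0.9 is kept, including the token that crosses the threshold:

import torch

sampler = Sampler(do_sample=True, p=0.9)
logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])  # probs ≈ [0.64, 0.24, 0.09, 0.03]
filtered = sampler.apply_nucleus_filter(logits)
# cumulative probs ≈ [0.64, 0.88, 0.97, 1.00], so the first three tokens are
# kept and only the last is masked: tensor([[3.0, 2.0, 1.0, -inf]])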

class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
......