# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from dataclasses import dataclass, field
from typing import Optional

import torch
from torch import Tensor

from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    # Capacity limits supplied by the caller when the object is created.
    max_sequence_len: int
    max_batch_size: int
    # Current offsets along the sequence and batch dimensions; both start at 0.
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    # default_factory gives every instance its own dict instead of a shared one.
    # presumably filled by the model with cached keys/values -- TODO confirm
    key_value_memory_dict: dict = field(default_factory=dict)
    # NOTE(review): consumed by the model, not by this class -- semantics live there.
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place.

    Arguments:
        logits: floating-point Tensor whose last dimension indexes the candidate
            tokens. It is modified in place.
        top_p: cumulative-probability mass to keep. Values <= 0 or >= 1 leave
            ``logits`` untouched.
    Returns: None. The result is written into ``logits``.
    """
    # top_p >= 1 keeps the full distribution, so there is nothing to filter.
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # Sort ascending so the cumulative sum accumulates the least likely tokens first.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop the low-probability tail whose total mass is at most (1 - top_p).
    # The most likely token always has cumulative probability ~1 and is kept.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the mask from sorted order back to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    # Must be the in-place variant: rebinding a local name would leave the
    # caller's tensor unfiltered.
    logits.masked_fill_(indices_to_remove, float('-inf'))


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Pick one token index per row, greedily or by sampling.

    ``top_k == 1`` takes the argmax. ``top_k > 1`` samples among the ``top_k``
    highest logits. ``top_k <= 0`` samples over the whole vocabulary. In both
    sampling modes the logits are divided by ``temperature`` and passed through
    the top-p filter before the softmax.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    Returns: Tensor of shape (batch_size,) holding the chosen token indices.
    """
    # Greedy decoding needs neither temperature nor filtering.
    if top_k == 1:
        return logits.argmax(dim=-1)

    if top_p > 0.0:
        assert top_p <= 1.0, 'top-p should be in (0, 1].'

    if top_k <= 0:
        # No top-k restriction: draw from the full distribution.
        scaled = logits / temperature
        modify_logits_for_top_p_filtering(scaled, top_p)
        probs = torch.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

    # Never ask for more candidates than the vocabulary holds.
    num_candidates = min(top_k, logits.size(-1))
    candidate_logits, candidate_ids = torch.topk(logits, num_candidates, dim=-1)
    candidate_logits /= temperature
    modify_logits_for_top_p_filtering(candidate_logits, top_p)
    probs = torch.softmax(candidate_logits, dim=-1)
    # The draw indexes into the candidate list; map it back to vocabulary ids.
    drawn = torch.multinomial(probs, num_samples=1)
    return candidate_ids.gather(-1, drawn).squeeze(dim=-1)


def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fused_ft_kernel=True):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        model: callable invoked as ``model(ids, inference_params=...)`` for the prompt and
            ``model(ids, position_ids=..., inference_params=...)`` for each later step; the
            returned object must expose a ``logits`` attribute.
        max_length: int, total length (prompt + generated tokens) to stop at.
        top_k, top_p, temperature: sampling controls forwarded to ``sample`` at every step.
        fused_ft_kernel: forwarded to ``InferenceParams``.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length). One token is always generated from the prompt, so the
            length is seq_len + 1 if max_length <= seq_len.
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                       fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        # Prompt pass: run the whole prompt and keep only the last position's logits.
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        # Check the budget BEFORE each step so that max_length == seq_len + 1 does not
        # produce an extra token beyond max_length.
        while inference_params.sequence_len_offset < max_length - 1:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            # top_p must be passed here too, otherwise nucleus filtering would only
            # affect the first generated token.
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )


class GenerationMixin:
    """Mixin that adds a ``generate`` method to a model class by delegating to ``decode``."""

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        """Generate tokens with ``decode``, using ``self`` as the model.

        Arguments:
            input_ids: prompt token ids, passed through to ``decode``.
            max_length: passed through to ``decode``.
            top_k, top_p, temperature: sampling controls passed through to ``decode``.
            return_dict_in_generate: if True, return the full output object;
                otherwise return only its ``sequences`` field.
            output_scores: if False, the ``scores`` field of the output is set to None.
            **kwargs: extra keyword arguments forwarded to ``decode``.
        """
        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences