generation.py
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from dataclasses import dataclass, field
import torch

from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)

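# The helper below is an illustrative sketch, not part of the original file:
# it shows roughly how an attention layer could use InferenceParams to keep a
# key/value cache across decoding steps. The helper name and the cache layout
# (batch, seqlen, 2, nheads, head_dim) are assumptions for this example.
def _update_kv_cache_sketch(kv, inference_params, layer_idx):
    """kv: (batch, seqlen, 2, nheads, head_dim) for the current step.
    Writes kv into the per-layer cache at sequence_len_offset and returns the
    cached keys/values up to and including the current step."""
    if layer_idx not in inference_params.key_value_memory_dict:
        # Allocate the full buffer up front; later steps only index into it.
        kv_cache = torch.empty(
            inference_params.max_batch_size, inference_params.max_sequence_len,
            2, kv.shape[-2], kv.shape[-1], dtype=kv.dtype, device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    seq_start = inference_params.sequence_len_offset
    seq_end = seq_start + kv.shape[1]
    kv_cache[batch_start:batch_end, seq_start:seq_end, ...] = kv
    return kv_cache[batch_start:batch_end, :seq_end, ...]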

def greedy_decode(input_ids, model, max_length):
    """Greedy decoding. This is a very simple implementation.
    We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
    scores = []
    with torch.inference_mode():
        # Prompt processing: run the whole prompt through the model once and
        # keep only the logits at the last position.
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = logits.argmax(dim=-1)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        # Decoding: feed back one token at a time; the model reuses the context
        # cached in inference_params instead of re-processing the prompt.
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            next_token = logits.argmax(dim=-1)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            # Stop once the full sequence (prompt + generated) reaches max_length.
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    return GreedySearchDecoderOnlyOutput(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )


class GenerationMixin:

    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
        output = greedy_decode(input_ids, self, max_length)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
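
# Usage sketch (illustrative): assumes `MyModel` mixes in GenerationMixin and
# that its forward accepts `inference_params` and `position_ids` as used above.
#
#     model = MyModel(config).cuda().eval()
#     input_ids = torch.randint(0, vocab_size, (2, 16), device='cuda')
#     out = model.generate(input_ids, max_length=64,
#                          return_dict_in_generate=True, output_scores=True)
#     out.sequences.shape  # (2, 64): the prompt followed by the greedy tokens
#     len(out.scores)      # 48 == 64 - 16, one logits tensor per generated token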