# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Inference API."""


import torch

from megatron.core import mpu
from .communication import broadcast_float_list
from .generation import (
        generate_tokens_probs_and_return_on_first_stage,
        score_and_return_on_first_stage,
        beam_search_and_return_on_first_stage)
from .tokenization import (
    tokenize_prompts,
    detokenize_generations)
from .forward_step import ForwardStep

def generate_and_post_process(model,
                              forward_step=ForwardStep,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
                              top_p_decay=0.0,
                              top_p_bound=0.0,
                              temperature=1.0,
                              add_BOS=False,
                              use_eod_token_for_early_termination=True,
                              stop_on_double_eol=False,
                              stop_on_eol=False,
                              prevent_newline_after_colon=False,
                              random_seed=-1,
                              detokenize_segments=True,
                              data_parallel=False,
                              return_topk_logprobs=0):
    """Run inference and post-process outputs, i.e., detokenize,
xingjinliang's avatar
xingjinliang committed
39
40
41
42
43
44
45
    move to cpu and convert to list.

    Args:
        data_parallel (bool): Enable data parallel text generation. Note: Caller must ensure
            that 1) different data parallel model replicas are provided different prompts and
            2) outputs from the different model replicas are gathered.
    """

    # Main inference.
    tokens, lengths, output_log_probs, logprobs_topk = generate(
        model,
        forward_step=forward_step,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
        top_p_decay=top_p_decay,
        top_p_bound=top_p_bound,
        temperature=temperature,
        add_BOS=add_BOS,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        stop_on_double_eol=stop_on_double_eol,
        stop_on_eol=stop_on_eol,
        prevent_newline_after_colon=prevent_newline_after_colon,
        random_seed=random_seed,
        data_parallel=data_parallel)

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, detokenize_segments)

        if return_output_log_probs:
            output_log_probs = output_log_probs.cpu().numpy().tolist()
            for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
                output_log_probs[i] = prob[:len(seg)-1]

        if return_topk_logprobs > 0:
            assert tokens_to_generate == 0, \
                "return_topk_logprobs is only supported when scoring " \
                "(tokens_to_generate == 0)."
            return prompts_plus_generations, prompts_plus_generations_segments, \
                output_log_probs, tokens, logprobs_topk
        else:
            return prompts_plus_generations, prompts_plus_generations_segments, \
                output_log_probs, tokens

    return None
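
# A minimal usage sketch (hypothetical driver code, not part of this module).
# Every rank must call generate_and_post_process collectively; prompts only
# need to be real on rank 0, and only the first pipeline stage receives a
# non-None result.
#
#   result = generate_and_post_process(
#       model,
#       prompts=["Deep learning is"],
#       tokens_to_generate=32,
#       return_output_log_probs=True,
#       top_k_sampling=40,
#       temperature=0.8)
#   if result is not None:
#       texts, segments, log_probs, tokens = result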

def generate(model,
             forward_step=None,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             top_p_decay=0.0,
             top_p_bound=0.0,
             temperature=1.0,
             add_BOS=False,
             use_eod_token_for_early_termination=True,
             stop_on_double_eol=False,
             stop_on_eol=False,
             prevent_newline_after_colon=False,
             random_seed=-1,
             data_parallel=False):
    """Given prompts and input parameters, run inference.

    Args:
        data_parallel (bool): Enable data parallel text generation.

    Returns:
       tokens: prompts plus the generated tokens.
       lengths: length of the prompt + generations. Note that we can
           discard tokens in the tokens tensor that are after the
           corresponding length.
       output_log_probs: log probs of the tokens.
       logprobs_topk: top-k log probs of the tokens (see
           return_topk_logprobs in generate_and_post_process above).
    """
    # Make sure input params are available to all ranks.
    values = [tokens_to_generate,
              return_output_log_probs,
              top_k_sampling, top_p_sampling, top_p_decay, top_p_bound,
              temperature, add_BOS, use_eod_token_for_early_termination,
              stop_on_double_eol,
              stop_on_eol,
              prevent_newline_after_colon,
              random_seed]

    values_float_tensor = broadcast_float_list(
        len(values), float_list=values, data_parallel=data_parallel)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    top_k_sampling = int(values_float_tensor[2].item())
    top_p_sampling = values_float_tensor[3].item()
    top_p_decay = values_float_tensor[4].item()
    top_p_bound = values_float_tensor[5].item()
    temperature = values_float_tensor[6].item()
    add_BOS = bool(values_float_tensor[7].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
    stop_on_double_eol = bool(values_float_tensor[9].item())
    stop_on_eol = bool(values_float_tensor[10].item())
    prevent_newline_after_colon = bool(values_float_tensor[11].item())
    random_seed = int(values_float_tensor[12].item())

    if random_seed != -1:
        torch.random.manual_seed(random_seed)

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None

    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS,
        data_parallel=data_parallel)

    if tokens_to_generate == 0:
        return score_and_return_on_first_stage(
            model, context_tokens_tensor, context_length_tensor)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, forward_step, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        top_k=top_k_sampling,
        top_p=top_p_sampling,
        top_p_decay=top_p_decay,
        top_p_bound=top_p_bound,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        stop_on_double_eol=stop_on_double_eol,
        stop_on_eol=stop_on_eol,
        prevent_newline_after_colon=prevent_newline_after_colon)
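
# The parameter-broadcast idiom above packs every scalar argument into one
# float list so a single collective suffices; bools and ints round-trip
# through float32. A stand-in sketch of the assumed semantics of
# broadcast_float_list (hypothetical, not the actual helper in
# .communication):
#
#   def _broadcast_float_list(size, float_list=None, src=0):
#       tensor = torch.empty(size, dtype=torch.float32,
#                            device=torch.cuda.current_device())
#       if torch.distributed.get_rank() == src:
#           tensor.copy_(torch.tensor(float_list, dtype=torch.float32))
#       torch.distributed.broadcast(tensor, src)
#       return tensor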

def beam_search_and_post_process(model,
                                 forward_step=ForwardStep,
                                 prompts=None,
                                 tokens_to_generate=0,
                                 beam_size=0,
                                 add_BOS=False,
                                 stop_token=50256,
                                 num_return_gen=1,
                                 length_penalty=1,
                                 prevent_newline_after_colon=False,
                                 detokenize_segments=True):
    """Run beam search and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, scores = beam_search(model,
                                 forward_step=forward_step,
                                 prompts=prompts,
                                 tokens_to_generate=tokens_to_generate,
                                 beam_size=beam_size,
                                 add_BOS=add_BOS,
                                 stop_token=stop_token,
                                 num_return_gen=num_return_gen,
                                 length_penalty=length_penalty,
                                 prevent_newline_after_colon=prevent_newline_after_colon)
    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        lengths = tokens.size(1) * torch.ones(
            beam_size, dtype=torch.int64, device=torch.cuda.current_device())
        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, detokenize_segments)
        scores = scores.cpu().numpy().tolist()
        return prompts_plus_generations, prompts_plus_generations_segments, scores

    return None
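
# A minimal usage sketch (hypothetical driver code): all ranks enter
# collectively, and the first pipeline stage receives num_return_gen
# candidates ranked by their length-penalized scores.
#
#   result = beam_search_and_post_process(
#       model,
#       prompts=["The capital of France is"],
#       tokens_to_generate=16,
#       beam_size=4,
#       num_return_gen=2)
#   if result is not None:
#       texts, segments, scores = result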

def beam_search(model, forward_step, prompts=None, tokens_to_generate=0,
                beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1,
                length_penalty=1, prevent_newline_after_colon=False):
    """Given prompts, run beam search. Note that the outputs are only
    available on the first pipeline stage."""
    # Make sure input params are available to all ranks.
    values = [tokens_to_generate,
              beam_size,
              add_BOS,
              stop_token,
              num_return_gen,
              length_penalty,
              prevent_newline_after_colon]
    values_float_tensor = broadcast_float_list(len(values), float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    beam_size = int(values_float_tensor[1].item())
    add_BOS = bool(values_float_tensor[2].item())
    stop_token = int(values_float_tensor[3].item())
    num_return_gen = int(values_float_tensor[4].item())
    length_penalty = values_float_tensor[5].item()
    prevent_newline_after_colon = bool(values_float_tensor[6].item())

    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

    return beam_search_and_return_on_first_stage(
        model, forward_step, context_tokens_tensor, context_length_tensor,
        beam_size, stop_token=stop_token, num_return_gen=num_return_gen,
        length_penalty=length_penalty,
        prevent_newline_after_colon=prevent_newline_after_colon)