# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference API."""


import torch

from megatron import mpu
from .communication import broadcast_float_list
from .generation import (
    generate_tokens_probs_and_return_on_first_stage,
    score_and_return_on_first_stage)
from .tokenization import (
    tokenize_prompts,
    detokenize_generations)


def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
                              temperature=1.0,
                              add_BOS=False,
                              use_eod_token_for_early_termination=True,
                              just_score=False):
    """Run inference and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, lengths, output_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
        temperature=temperature,
        add_BOS=add_BOS,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        just_score=just_score)

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, True)

        if return_output_log_probs:
            output_log_probs = output_log_probs.cpu().numpy().tolist()
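            # A sequence of N tokens has N - 1 next-token log probs, so trim
            # each sample's list to its segment length minus one.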
            for i, (prob, seg) in enumerate(zip(
                    output_log_probs, prompts_plus_generations_segments)):
                output_log_probs[i] = prob[:len(seg)-1]

        return prompts_plus_generations, prompts_plus_generations_segments, \
            output_log_probs, tokens

    return None
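
# A minimal usage sketch (illustrative only, not part of the original module):
# calling generate_and_post_process from a driver script. The `model` handle
# and prompt string here are hypothetical placeholders. Note that on pipeline
# stages other than the first, the call returns None.
#
#     prompts_plus_generations, segments, log_probs, tokens = \
#         generate_and_post_process(model,
#                                   prompts=['Hello, my name is'],
#                                   tokens_to_generate=32,
#                                   return_output_log_probs=True,
#                                   top_k_sampling=1)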


def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             temperature=1.0,
             add_BOS=False,
             use_eod_token_for_early_termination=True,
             just_score=False):
    """Given prompts and input parameters, run inference and return:
       tokens: prompts plus the generated tokens.
       lengths: length of the prompt + generations. Note that we can
           discard tokens in the tokens tensor that are after the
           corresponding length.
       output_log_probs: log probs of the tokens.
    """

    # Make sure input params are available to all ranks.
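    # The scalars are packed into a single float list so one broadcast call
    # suffices; each rank casts them back to their original types below.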
    values = [tokens_to_generate,
              return_output_log_probs,
              top_k_sampling, top_p_sampling,
              temperature, add_BOS,
              use_eod_token_for_early_termination, just_score]
    values_float_tensor = broadcast_float_list(8, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    top_k_sampling = int(values_float_tensor[2].item())
    top_p_sampling = values_float_tensor[3].item()
    temperature = values_float_tensor[4].item()
    add_BOS = bool(values_float_tensor[5].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
    just_score = bool(values_float_tensor[7].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None
        # assert tokens_to_generate > 0
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

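    # Scoring-only mode: return log probs for the provided tokens and skip
    # generation entirely.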
    if just_score:
        return score_and_return_on_first_stage(
            model, context_tokens_tensor, context_length_tensor)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        top_k=top_k_sampling,
        top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
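
# A minimal scoring sketch (illustrative only, not part of the original
# module), reusing the hypothetical `model` handle from the example above:
#
#     result = generate(model,
#                       prompts=['The quick brown fox'],
#                       just_score=True)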