# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prompting the pretrained language model to generate knowledge/response"""

import json
import torch
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from tasks.knwl_dialo.utils import get_token_stream
# from megatron.text_generation import generate_and_post_process


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process
    )
    return model


def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""
    
    # get the args and the tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                    'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
                "prompt_type must be either 'knowledge' or 'response'!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
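        # each line of the prompt file is a JSON dict with a single key of the
        # form "<topic> <last user turn>", whose value is the list of prompt
        # examples for that key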
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

    context_count = 0
    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
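                # each input line is tab-separated: topic, dialogue context
                # (turns joined by " [SEP] "), and, for response generation,
                # the knowledge string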
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    context = turns[-1]
                    raw_text += "( " + context + " ) " + topic + " =>"
                
                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                # raw_text = "EMPTY TEXT"

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            # get the generation outputs (in decode_tokens)
            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass
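            # after the loop, decode_tokens holds the last item yielded by the
            # stream, i.e. the completed generation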
            # outputs = generate_and_post_process(
            #             model=model, 
            #             prompts=[raw_text], 
            #             tokens_to_generate=args.out_seq_length,
            #             top_k_sampling=1)
            # prompts_plus_generations = outputs[0]

            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
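                    # strip the prompt prefix and keep only the first generated line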
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    
                    generated_output = trim_decode_tokens.split("\n")[0]
                    generated_output = generated_output.strip()
                    fname_out.write(generated_output)
                    fname_out.write("\n")
                    
                    # generations = prompts_plus_generations[raw_text_len:]
                    # generations = generations.split("\n")[0]
                    # generations = generations.strip()
                    # fname_out.write(generations)
                    # fname_out.write("\n")

            raw_text = None
            context_count += 1
            if input_pos == input_count:
                return


def main():
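    # NOTE: assumes Megatron has already been initialized (e.g. by calling
    # initialize_megatron in the task entry script) before main() is invoked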

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)