
"""Prompting the pretrained language model to generate knowledge/response"""

import json
import torch
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from tasks.knwl_dialo.utils import get_token_stream


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process
    )
    return model


def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""
    
    # get the arguments and the tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
                "Please input a correct prompt type!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
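                # each line is a json dict with a single key; the key is expected
                # to be "<topic> <last user turn>" of a test sample (see how it is
                # looked up during generation below), and the value is the list of
                # prompt examples for that sample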

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

    context_count = 0
    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
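                # each test line is tab-separated:
                #   splits[0]: the dialogue topic
                #   splits[1]: the dialogue turns, joined by " [SEP] "
                #   splits[2]: the knowledge sentence (response generation only)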
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    # the input reuses the last dialogue turn and topic extracted above
                    raw_text += "( " + last_turn + " ) " + topic + " =>"
                
                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
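                    # assemble the response-generation input:
                    # "Topic: <topic>. User says: <last turn> We know that: <knowledge> System replies:"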
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            # get the generation outputs (in decode_tokens)
            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass
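            # the loop above drains the stream; decode_tokens from the last
            # iteration holds the completed generation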
            
            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
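                    # raw_text_len is a character offset into the detokenized string,
                    # so the slice above drops the prompt and keeps only the generation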
                    
                    generated_output = trim_decode_tokens.split("\n")[0]
                    generated_output = generated_output.strip()
                    fname_out.write(generated_output)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1
            if input_pos == input_count:
                return


def main():

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)