"""Sample Generate Controllable Dialog Model"""

import os
import sys
# Make the repository root importable when this script is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import torch
from transformers import DPRQuestionEncoderTokenizer
from megatron import get_args
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import (
    dialog_with_gpt_control_interactive,
    dialog_with_dpr_control_interactive,
)


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)

    return model


def add_control_dialog_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')
    group.add_argument("--ctrl-type", type=str, default="", 
                        help="Either dpr or gpt")
    group.add_argument("--ctrl-hidden-size", type=int, default=1024, 
                        help="hidden-size of gpt control model")
    group.add_argument("--ctrl-num-layers", type=int, default=24, 
                        help="num-layers of gpt control model")
    group.add_argument("--ctrl-num-attention-heads", type=int, default=16,
                        help="num-attention-heads of gpt control model")
    group.add_argument("--ctrl-gpt-load", type=str, default="",
                        help="checkpoint path of the gpt control model")
    group.add_argument("--ctrl-dpr-load", type=str, default="",
                        help="checkpoint path of the dpr control model")
    group.add_argument("--knowledge-corpus-path", type=str, default="",
                        help="The path for the knowledge corpus")
    group.add_argument("--knowledge-corpus-emb", type=str, default="",
                        help="The path for the knowledge embedding")                 
    group.add_argument('--spec-toks', type=str, default=None,
                        help='additional special tokens')
    group.add_argument('--add-separator', action="store_true",
                        help='Add separator for the inputs')
    
    return parser


def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_control_dialog_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print_rank_0("Interleaved pipeline schedule is not yet supported "
                     "for text generation.")
        sys.exit()

    # Set up conversational model
    conv_model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(conv_model, None, None)

    assert len(conv_model) == 1, "Above condition should have caught this"
    conv_model = conv_model[0]

    # Set up control model
    assert args.ctrl_type in ["gpt", "dpr"], \
                "please input a correct control model type"
    
    if args.ctrl_type == "gpt":
        # Rebuild the global args so that get_model() constructs the GPT
        # control model with its own dimensions rather than those of the
        # conversational model.
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0
        args.hidden_size = args.ctrl_hidden_size
        args.ffn_hidden_size = 4 * args.hidden_size
        args.num_layers = args.ctrl_num_layers
        args.num_attention_heads = args.ctrl_num_attention_heads
        args.load = args.ctrl_gpt_load

        ctrl_model = get_model(model_provider)
        if args.load:  # --ctrl-gpt-load defaults to "", so check truthiness
            _ = load_checkpoint(ctrl_model, None, None)
        ctrl_model = ctrl_model[0]

        dialog_with_gpt_control_interactive(conv_model, ctrl_model,
                                            args.add_separator)

    else:
        # Load the DPR control model (saved as a full serialized module)
        # along with its question-encoder tokenizer.
        print_rank_0("> Loading model from %s" % args.ctrl_dpr_load)
        ctrl_model = torch.load(args.ctrl_dpr_load)
        ctrl_model.cuda()
        ctrl_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            'facebook/dpr-question_encoder-single-nq-base')

        print_rank_0("> Loading knowledge corpus and embeddings")
        with open(args.knowledge_corpus_path, "r") as f:
            knowledge_corpus = f.readlines()
        knowledge_corpus_emb = torch.load(args.knowledge_corpus_emb)
        knowledge_corpus_emb = knowledge_corpus_emb.cuda()

        assert knowledge_corpus_emb.size(0) == len(knowledge_corpus), \
            "The knowledge corpus and its embeddings must have the same size"

        dialog_with_dpr_control_interactive(conv_model, ctrl_model,
                                            ctrl_tokenizer, knowledge_corpus,
                                            knowledge_corpus_emb,
                                            args.add_separator)


if __name__ == "__main__":
    main()