Commit 7122b878 authored by zihanl

delete interactive file

parent 951ef3ae
"""Sample Generate Controllable Dialog Model"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import argparse
import torch
from transformers import DPRQuestionEncoderTokenizer
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import dialog_with_gpt_control_interactive, dialog_with_dpr_control_interactive
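
# This tool loads a conversational GPT model plus one of two control models:
# a separately sized GPT ("gpt") or a DPR question encoder ("dpr") that
# retrieves entries from a precomputed knowledge-corpus embedding matrix.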


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)
    return model


def add_control_dialog_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                            'instead of using previously computed keys/values.')
    group.add_argument("--ctrl-type", type=str, default="",
                       help="Control model type: either 'gpt' or 'dpr'.")
    group.add_argument("--ctrl-hidden-size", type=int, default=1024,
                       help="Hidden size of the GPT control model.")
    group.add_argument("--ctrl-num-layers", type=int, default=24,
                       help="Number of layers of the GPT control model.")
    group.add_argument("--ctrl-num-attention-heads", type=int, default=16,
                       help="Number of attention heads of the GPT control model.")
    group.add_argument("--ctrl-gpt-load", type=str, default="",
                       help="Checkpoint path of the GPT control model.")
    group.add_argument("--ctrl-dpr-load", type=str, default="",
                       help="Checkpoint path of the DPR control model.")
    group.add_argument("--knowledge-corpus-path", type=str, default="",
                       help="Path to the knowledge corpus (one entry per line).")
    group.add_argument("--knowledge-corpus-emb", type=str, default="",
                       help="Path to the precomputed knowledge embeddings.")
    group.add_argument('--spec-toks', type=str, default=None,
                       help='Additional special tokens.')
    group.add_argument('--add-separator', action="store_true",
                       help='Add a separator token to the inputs.')
    return parser
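

# Illustrative DPR-mode launch; a sketch only. The paths below are
# placeholders, and the usual Megatron model/tokenizer arguments
# (--load, --vocab-file, ...) are omitted for brevity:
#   python <path-to-this-script> \
#       --ctrl-type dpr \
#       --ctrl-dpr-load /path/to/dpr_ctrl.pt \
#       --knowledge-corpus-path /path/to/corpus.txt \
#       --knowledge-corpus-emb /path/to/corpus_emb.pt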


def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_control_dialog_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})
    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up the conversational model.
    conv_model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(conv_model, None, None)
    assert len(conv_model) == 1, "Above condition should have caught this"
    conv_model = conv_model[0]

    # Set up the control model.
    assert args.ctrl_type in ["gpt", "dpr"], \
        "--ctrl-type must be either 'gpt' or 'dpr'"
    if args.ctrl_type == "gpt":
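        # get_model() builds from the global args, so override the relevant
        # sizes and checkpoint path with the control model's settings first.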
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0
        args.hidden_size = args.ctrl_hidden_size
        args.ffn_hidden_size = 4 * args.hidden_size
        args.num_layers = args.ctrl_num_layers
        args.num_attention_heads = args.ctrl_num_attention_heads
        args.load = args.ctrl_gpt_load
        ctrl_model = get_model(model_provider)
        if args.load is not None:
            _ = load_checkpoint(ctrl_model, None, None)
        ctrl_model = ctrl_model[0]
        dialog_with_gpt_control_interactive(conv_model, ctrl_model,
                                            args.add_separator)
    else:
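        # DPR control path: load the saved encoder and its tokenizer, then
        # the knowledge corpus and its precomputed embeddings for retrieval.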
print_rank_0("> Loading model from %s" % args.ctrl_dpr_load)
ctrl_model = torch.load(args.ctrl_dpr_load)
ctrl_model.cuda()
ctrl_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
print_rank_0("> Loading knowledge corpus and embeddings")
with open(args.knowledge_corpus_path, "r") as f:
knowledge_corpus = f.readlines()
knowledge_corpus_emb = torch.load(args.knowledge_corpus_emb)
knowledge_corpus_emb = knowledge_corpus_emb.cuda()
assert knowledge_corpus_emb.size()[0] == len(knowledge_corpus), \
"The size of knowledge corpus and embeddings should be the same"
dialog_with_dpr_control_interactive(conv_model, ctrl_model,
ctrl_tokenizer, knowledge_corpus,
knowledge_corpus_emb, args.add_separator)
if __name__ == "__main__":
main()