Commit 051f58f1 authored by liangjing

v1

parent 0024a5c6
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import hashlib
import os

from megatron import get_retro_args


def get_query_workdir():
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "query")


def get_neighbor_dirname(key, dataset):
    hashes = ",".join([d.desc_hash for d in dataset.datasets])
    hash = hashlib.md5(hashes.encode()).hexdigest()
    return os.path.join(get_query_workdir(), os.path.basename(f"{key}_{hash}"))
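The neighbor directory name is derived deterministically: the per-dataset description hashes are joined and re-hashed with MD5, so the same blend of datasets always resolves to the same query directory. A minimal standalone illustration of that naming scheme (the hash values and the "train" key are placeholders, not values from this repo):

import hashlib

# Hypothetical per-dataset description hashes (placeholder values).
desc_hashes = ["a1b2c3", "d4e5f6"]
digest = hashlib.md5(",".join(desc_hashes).encode()).hexdigest()
print(f"train_{digest}")  # deterministic directory name for key="train"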
@@ -8,6 +8,7 @@ from megatron import get_retro_args
 from megatron.tokenizer.tokenizer import (
     _BertWordPieceTokenizer,
     _GPT2BPETokenizer,
+    _GPTSentencePieceTokenizer,
 )
@@ -28,10 +29,18 @@ def get_num_chunks_per_sample():
 def get_gpt_tokenizer():
     '''GPT (BPE) tokenizer.'''
     args = get_retro_args()
-    return _GPT2BPETokenizer(
-        vocab_file=args.retro_gpt_vocab_file,
-        merge_file=args.retro_gpt_merge_file,
-    )
+    tokenizer_type = args.retro_gpt_tokenizer_type
+    if tokenizer_type == "GPT2BPETokenizer":
+        assert args.retro_gpt_vocab_file and args.retro_gpt_merge_file
+        return _GPT2BPETokenizer(
+            vocab_file=args.retro_gpt_vocab_file,
+            merge_file=args.retro_gpt_merge_file,
+        )
+    elif tokenizer_type == 'GPTSentencePieceTokenizer':
+        assert args.retro_gpt_tokenizer_model is not None
+        return _GPTSentencePieceTokenizer(args.retro_gpt_tokenizer_model)
+    else:
+        raise Exception("unrecognized gpt tokenizer, '%s'." % tokenizer_type)
 
 
 def get_bert_tokenizer():
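For illustration, a minimal standalone sketch of the same dispatch pattern (stub classes and SimpleNamespace stand in for Megatron's tokenizer classes and get_retro_args(); the model path is a placeholder):

from types import SimpleNamespace

# Stand-ins for Megatron's tokenizer classes (illustrative only).
class GPT2BPEStub:
    def __init__(self, vocab_file, merge_file):
        self.kind = "bpe"

class SentencePieceStub:
    def __init__(self, model_file):
        self.kind = "sentencepiece"

def build_tokenizer(args):
    # Mirrors the branching added above: pick a tokenizer class
    # based on the configured type string.
    if args.tokenizer_type == "GPT2BPETokenizer":
        return GPT2BPEStub(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == "GPTSentencePieceTokenizer":
        return SentencePieceStub(args.tokenizer_model)
    raise Exception("unrecognized gpt tokenizer, '%s'." % args.tokenizer_type)

args = SimpleNamespace(tokenizer_type="GPTSentencePieceTokenizer",
                       tokenizer_model="/path/to/tokenizer.model")
assert build_tokenizer(args).kind == "sentencepiece"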
@@ -13,6 +13,7 @@ from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
+from megatron.arguments import core_transformer_config_from_args
 from megatron.text_generation_server import MegatronServer
 from megatron.text_generation import generate_and_post_process
 from megatron.text_generation import beam_search_and_post_process
@@ -21,8 +22,10 @@ import torch

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
+    config = core_transformer_config_from_args(get_args())
+
     print_rank_0('building GPT model ...')
-    model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
+    model = GPTModel(config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
     return model
@@ -37,6 +40,8 @@ def add_text_generate_args(parser):
                        help='Top k sampling.')
     group.add_argument("--out-seq-length", type=int, default=1024,
                        help='Size of the output generated text.')
+    group.add_argument("--port", type=int, default=5000,
+                       help='port for text generation server to run on')
     return parser
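A quick standalone check of how the new flag parses (hypothetical snippet using plain argparse, outside Megatron's argument machinery):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='text generation')
group.add_argument("--port", type=int, default=5000,
                   help='port for text generation server to run on')

assert parser.parse_args([]).port == 5000                   # default
assert parser.parse_args(["--port", "8080"]).port == 8080   # override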
@@ -50,6 +55,9 @@ if __name__ == "__main__":
     if args.num_layers_per_virtual_pipeline_stage is not None:
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()
+    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text "
+                 "generation.")
+    args.exit_on_missing_checkpoint = True
 
     # Set up model and load checkpoint
     model = get_model(model_provider, wrap_with_ddp=False)
@@ -60,7 +68,7 @@ if __name__ == "__main__":
     model = model[0]
     if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
         server = MegatronServer(model)
-        server.run("0.0.0.0")
+        server.run("0.0.0.0", port=args.port)
 
     while True:
         choice = torch.cuda.LongTensor(1)
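With the port now configurable, a client can target whichever port the server was launched on. A sketch of a request against the text generation server (assuming Megatron's standard PUT /api endpoint and a JSON response with a "text" field; host, port, prompt, and token count are illustrative):

import json
import requests

# Hypothetical client; assumes the server above was started with --port 5000.
url = "http://localhost:5000/api"
payload = {
    "prompts": ["Deep learning is"],  # batch of prompt strings
    "tokens_to_generate": 32,         # illustrative length
}
resp = requests.put(url, data=json.dumps(payload),
                    headers={"Content-Type": "application/json"})
print(resp.json()["text"])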