Commit 3aca1415 authored by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
Pipeline #1806 passed
@@ -5,11 +5,12 @@ import os
 import torch
 from megatron import get_args, get_retro_args
-from tools.bert_embedding.utils import get_index_path_map
+from tools.bert_embedding.utils import BlockPathMap
 from tools.retro.db.utils import get_merged_train_dataset as get_db_dataset
 from tools.retro.external_libs import h5py
 from .chunk_dataset import get_chunk_dataset_map
+from .utils import get_neighbor_dirname
 class RetroDataset(torch.utils.data.Dataset):
@@ -100,7 +101,7 @@ class RetroDataset(torch.utils.data.Dataset):
         return sample
-def get_retro_datasets():
+def get_retro_datasets(verify_sizes=True):
     '''Get train, valid, test retro datasets.'''
     args = get_args()
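Aside on the signature change above: callers that want the datasets without the chunk-count check can now pass verify_sizes=False. A hypothetical call site (not part of this MR), relying on the docstring's train/valid/test ordering:

# Hypothetical usage of the new flag; skips the sample-vs-neighbor count check.
train_ds, valid_ds, test_ds = get_retro_datasets(verify_sizes=False)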
@@ -116,24 +117,39 @@ def get_retro_datasets():
         chunk_dataset = chunk_ds_info["data"]
         neighbor_dir = chunk_ds_info["neighbor_dir"]
-        neighbor_path_map = get_index_path_map(neighbor_dir)
+        neighbor_path_map = BlockPathMap.from_dir(neighbor_dir,
+                                                  retro_args.retro_block_size)
         # Verify dataset prefixes.
-        sample_prefix = chunk_dataset.sample_dataset.datasets[0].index_prefix
-        neighbor_prefix = os.path.basename(neighbor_dir)
-        assert sample_prefix == neighbor_prefix, \
+        expected_dir = get_neighbor_dirname(data_key, chunk_dataset.sample_dataset)
+        assert expected_dir == neighbor_dir, \
             "inconsistent dataset source; '%s' vs. '%s'." % \
-            (sample_prefix, neighbor_prefix)
+            (expected_dir, neighbor_dir)
         # Verify num chunks.
         n_sample_chunks = len(chunk_dataset)
-        n_neighbor_chunks = len(neighbor_path_map.id_index_map)
-        if n_sample_chunks != n_neighbor_chunks:
-            print("neighbor_dir : %s" % neighbor_dir)
-            print("neighbor_path_map : %s" % neighbor_path_map)
-            raise Exception("num sampled chunks (%d) != num neighbor chunks (%d)"
-                            % (n_sample_chunks, n_neighbor_chunks))
+        n_neighbor_chunks = neighbor_path_map.max_idx
+        if not os.path.isdir(neighbor_dir):
+            if torch.distributed.get_rank() == 0:
+                raise Exception("neighbor directory '%s' not found; please "
+                                "compare --train-samples, --seq-length, --seed, "
+                                "--eval-iters, and --eval-interval, with "
+                                "retro preprocessing args." %
+                                neighbor_dir)
+            torch.distributed.barrier()
+            exit()
+        if verify_sizes and n_sample_chunks != n_neighbor_chunks:
+            if torch.distributed.get_rank() == 0:
+                print("neighbor_dir : %s" % neighbor_dir)
+                print("neighbor_path_map : %s" % neighbor_path_map)
+                raise Exception("num sampled chunks (%d) != num neighbor chunks "
+                                "(%d); did you complete querying the entire "
+                                "pretraining dataset?"
+                                % (n_sample_chunks, n_neighbor_chunks))
+            torch.distributed.barrier()
+            exit()
         # Retro dataset.
         retro_dataset_map[data_key] = RetroDataset(
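For context on the lookup change above: the old code built a per-file index path map, while the new code asks BlockPathMap (from tools/bert_embedding/utils.py) for a block-size-aware map and reads neighbor_path_map.max_idx as the neighbor chunk count. The sketch below is a minimal, hypothetical re-implementation of that idea, resolving a global chunk index to the HDF5 block file that covers it, assuming block files are named '<start>-<end>.hdf5'; the real BlockPathMap may differ in its details.

import glob
import os


class SimpleBlockPathMap:
    """Illustrative stand-in for BlockPathMap: maps a global chunk index to
    the neighbor block file ('<start>-<end>.hdf5') whose range covers it."""

    def __init__(self, path_map, block_size, max_idx):
        self.path_map = path_map      # block start index -> file path
        self.block_size = block_size
        self.max_idx = max_idx        # one past the largest covered index

    @classmethod
    def from_dir(cls, dirname, block_size):
        # Assumed naming convention '<start>-<end>.hdf5' (hypothetical).
        path_map, max_idx = {}, 0
        for path in glob.glob(os.path.join(dirname, "*.hdf5")):
            start, end = map(int, os.path.splitext(os.path.basename(path))[0].split("-"))
            path_map[start] = path
            max_idx = max(max_idx, end)
        return cls(path_map, block_size, max_idx)

    def __getitem__(self, idx):
        # Round the chunk index down to the start of its block.
        start = (idx // self.block_size) * self.block_size
        return self.path_map[start]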
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import hashlib
import os
from megatron import get_retro_args
def get_query_workdir():
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "query")
def get_neighbor_dirname(key, dataset):
    hashes = ",".join([ d.desc_hash for d in dataset.datasets ])
    hash = hashlib.md5(hashes.encode()).hexdigest()
    return os.path.join(get_query_workdir(), os.path.basename(f"{key}_{hash}"))
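The directory-name check in get_retro_datasets relies on get_neighbor_dirname above: the neighbor directory name combines the data key with an MD5 digest of the per-dataset description hashes, so changing the underlying dataset mix yields a different directory. A small self-contained illustration, using mock objects in place of the real Megatron datasets and a hypothetical workdir instead of get_retro_args():

import hashlib
import os
from types import SimpleNamespace

# Mock stand-ins for the real sample datasets; only desc_hash matters here.
dataset = SimpleNamespace(datasets=[
    SimpleNamespace(desc_hash="a1b2c3"),
    SimpleNamespace(desc_hash="d4e5f6"),
])

hashes = ",".join(d.desc_hash for d in dataset.datasets)
digest = hashlib.md5(hashes.encode()).hexdigest()

# With a hypothetical retro workdir, the 'train' split would resolve to:
print(os.path.join("/path/to/retro_workdir/query", f"train_{digest}"))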
@@ -8,6 +8,7 @@ from megatron import get_retro_args
 from megatron.tokenizer.tokenizer import (
     _BertWordPieceTokenizer,
     _GPT2BPETokenizer,
+    _GPTSentencePieceTokenizer,
 )
@@ -28,10 +29,18 @@ def get_num_chunks_per_sample():
 def get_gpt_tokenizer():
     '''GPT (BPE) tokenizer.'''
     args = get_retro_args()
-    return _GPT2BPETokenizer(
-        vocab_file=args.retro_gpt_vocab_file,
-        merge_file=args.retro_gpt_merge_file,
-    )
+    tokenizer_type = args.retro_gpt_tokenizer_type
+    if tokenizer_type == "GPT2BPETokenizer":
+        assert args.retro_gpt_vocab_file and args.retro_gpt_merge_file
+        return _GPT2BPETokenizer(
+            vocab_file=args.retro_gpt_vocab_file,
+            merge_file=args.retro_gpt_merge_file,
+        )
+    elif tokenizer_type == 'GPTSentencePieceTokenizer':
+        assert args.retro_gpt_tokenizer_model is not None
+        return _GPTSentencePieceTokenizer(args.retro_gpt_tokenizer_model)
+    else:
+        raise Exception("unrecognized gpt tokenizer, '%s'." % tokenizer_type)
 def get_bert_tokenizer():
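The tokenizer selection above is keyed on retro_gpt_tokenizer_type: GPT2BPETokenizer requires vocab/merge files, while GPTSentencePieceTokenizer requires a single tokenizer model. A rough sketch of the two argument shapes it expects (attribute names taken from the diff; file paths are placeholders):

from types import SimpleNamespace

# BPE configuration: vocab and merges files are required.
bpe_args = SimpleNamespace(
    retro_gpt_tokenizer_type="GPT2BPETokenizer",
    retro_gpt_vocab_file="/path/to/gpt2-vocab.json",     # placeholder path
    retro_gpt_merge_file="/path/to/gpt2-merges.txt",     # placeholder path
    retro_gpt_tokenizer_model=None,
)

# SentencePiece configuration: a single .model file is required.
sp_args = SimpleNamespace(
    retro_gpt_tokenizer_type="GPTSentencePieceTokenizer",
    retro_gpt_vocab_file=None,
    retro_gpt_merge_file=None,
    retro_gpt_tokenizer_model="/path/to/tokenizer.model",  # placeholder path
)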
@@ -13,6 +13,7 @@ from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
+from megatron.arguments import core_transformer_config_from_args
 from megatron.text_generation_server import MegatronServer
 from megatron.text_generation import generate_and_post_process
 from megatron.text_generation import beam_search_and_post_process
@@ -21,8 +22,10 @@ import torch
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
+    config = core_transformer_config_from_args(get_args())
     print_rank_0('building GPT model ...')
-    model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
+    model = GPTModel(config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
     return model
@@ -37,6 +40,8 @@ def add_text_generate_args(parser):
                        help='Top k sampling.')
     group.add_argument("--out-seq-length", type=int, default=1024,
                        help='Size of the output generated text.')
+    group.add_argument("--port", type=int, default=5000,
+                       help='port for text generation server to run on')
     return parser
@@ -50,6 +55,9 @@ if __name__ == "__main__":
     if args.num_layers_per_virtual_pipeline_stage is not None:
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()
+    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text "
+                 "generation.")
+    args.exit_on_missing_checkpoint = True
     # Set up model and load checkpoint
     model = get_model(model_provider, wrap_with_ddp=False)
@@ -60,7 +68,7 @@
     model = model[0]
     if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
         server = MegatronServer(model)
-        server.run("0.0.0.0")
+        server.run("0.0.0.0",port=args.port)
     while True:
         choice = torch.cuda.LongTensor(1)
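With the new --port flag, the server binds to the requested port instead of Flask's default 5000. As a usage illustration (assuming the standard Megatron text generation REST endpoint at /api that accepts PUT requests with a JSON body; host, port, and generation parameters below are placeholders), a minimal client could look like:

import json

import requests  # third-party HTTP client

# Placeholder host/port; assumes the server was started with --port 8123.
URL = "http://localhost:8123/api"

payload = {
    "prompts": ["Megatron-LM is"],
    "tokens_to_generate": 64,
}

# Send the prompt and print the generated continuation.
response = requests.put(URL,
                        data=json.dumps(payload),
                        headers={"Content-Type": "application/json; charset=UTF-8"})
print(response.json()["text"])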