Commit 3aca1415 authored by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
Pipeline #1806 passed
@@ -5,11 +5,12 @@ import os
 import torch
 
 from megatron import get_args, get_retro_args
-from tools.bert_embedding.utils import get_index_path_map
+from tools.bert_embedding.utils import BlockPathMap
 from tools.retro.db.utils import get_merged_train_dataset as get_db_dataset
 from tools.retro.external_libs import h5py
 
 from .chunk_dataset import get_chunk_dataset_map
+from .utils import get_neighbor_dirname
 
 
 class RetroDataset(torch.utils.data.Dataset):
@@ -100,7 +101,7 @@ class RetroDataset(torch.utils.data.Dataset):
         return sample
 
 
-def get_retro_datasets():
+def get_retro_datasets(verify_sizes=True):
     '''Get train, valid, test retro datasets.'''
 
     args = get_args()
@@ -116,24 +117,39 @@ def get_retro_datasets():
         chunk_dataset = chunk_ds_info["data"]
         neighbor_dir = chunk_ds_info["neighbor_dir"]
-        neighbor_path_map = get_index_path_map(neighbor_dir)
+        neighbor_path_map = BlockPathMap.from_dir(neighbor_dir,
+                                                  retro_args.retro_block_size)
 
         # Verify dataset prefixes.
-        sample_prefix = chunk_dataset.sample_dataset.datasets[0].index_prefix
-        neighbor_prefix = os.path.basename(neighbor_dir)
-        assert sample_prefix == neighbor_prefix, \
+        expected_dir = get_neighbor_dirname(data_key, chunk_dataset.sample_dataset)
+        assert expected_dir == neighbor_dir, \
             "inconsistent dataset source; '%s' vs. '%s'." % \
-            (sample_prefix, neighbor_prefix)
+            (expected_dir, neighbor_dir)
 
         # Verify num chunks.
         n_sample_chunks = len(chunk_dataset)
-        n_neighbor_chunks = len(neighbor_path_map.id_index_map)
+        n_neighbor_chunks = neighbor_path_map.max_idx
 
-        if n_sample_chunks != n_neighbor_chunks:
-            print("neighbor_dir : %s" % neighbor_dir)
-            print("neighbor_path_map : %s" % neighbor_path_map)
-            raise Exception("num sampled chunks (%d) != num neighbor chunks (%d)"
-                            % (n_sample_chunks, n_neighbor_chunks))
+        if not os.path.isdir(neighbor_dir):
+            if torch.distributed.get_rank() == 0:
+                raise Exception("neighbor directory '%s' not found; please "
+                                "compare --train-samples, --seq-length, --seed, "
+                                "--eval-iters, and --eval-interval, with "
+                                "retro preprocessing args." %
+                                neighbor_dir)
+            torch.distributed.barrier()
+            exit()
+
+        if verify_sizes and n_sample_chunks != n_neighbor_chunks:
+            if torch.distributed.get_rank() == 0:
+                print("neighbor_dir : %s" % neighbor_dir)
+                print("neighbor_path_map : %s" % neighbor_path_map)
+                raise Exception("num sampled chunks (%d) != num neighbor chunks "
+                                "(%d); did you complete querying the entire "
+                                "pretraining dataset?"
+                                % (n_sample_chunks, n_neighbor_chunks))
+            torch.distributed.barrier()
+            exit()
 
         # Retro dataset.
         retro_dataset_map[data_key] = RetroDataset(
......
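A note on the neighbor lookup above: `BlockPathMap` is imported from `tools.bert_embedding.utils`, and its implementation is not part of this diff. As a rough, hypothetical sketch of the idea it encapsulates, assume the query stage writes fixed-size HDF5 blocks named `<start>-<end>.hdf5`; a map from chunk index to block file might then look like this (illustrative only, not the Megatron class):

```python
import glob
import os


class SimpleBlockPathMap:
    """Illustrative stand-in: maps a global chunk index to the HDF5 block
    file that holds it, assuming files are named '<start>-<end>.hdf5' and
    each block spans `block_size` chunks."""

    def __init__(self, block_paths, block_size):
        self.block_size = block_size
        self.path_map = {}  # block start index -> file path
        max_idx = 0
        for path in block_paths:
            start, end = (int(i) for i in
                          os.path.splitext(os.path.basename(path))[0].split("-"))
            self.path_map[start] = path
            max_idx = max(max_idx, end)
        self.max_idx = max_idx  # total number of indexed chunks

    @classmethod
    def from_dir(cls, dirname, block_size):
        # Build the map from every block file found in the directory.
        return cls(sorted(glob.glob(os.path.join(dirname, "*.hdf5"))), block_size)

    def __getitem__(self, idx):
        # Round the chunk index down to the owning block's start index.
        return self.path_map[(idx // self.block_size) * self.block_size]
```

Under that assumption, `max_idx` counts the neighbor chunks written so far, which is what the `verify_sizes` branch compares against `len(chunk_dataset)`.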
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import hashlib
import os

from megatron import get_retro_args


def get_query_workdir():
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "query")


def get_neighbor_dirname(key, dataset):
    hashes = ",".join([ d.desc_hash for d in dataset.datasets ])
    hash = hashlib.md5(hashes.encode()).hexdigest()
    return os.path.join(get_query_workdir(), os.path.basename(f"{key}_{hash}"))
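The new helper derives the neighbor directory from the sample dataset's own description hashes rather than from a filename prefix, which is what the `expected_dir == neighbor_dir` assertion above relies on. A small illustration of the computation, using hypothetical stand-in dataset objects:

```python
import hashlib
import os
from types import SimpleNamespace

# Hypothetical stand-ins for a blended dataset: each sub-dataset carries a
# 'desc_hash' describing its source corpus and preprocessing.
dataset = SimpleNamespace(datasets=[
    SimpleNamespace(desc_hash="0a1b2c"),
    SimpleNamespace(desc_hash="3d4e5f"),
])

hashes = ",".join(d.desc_hash for d in dataset.datasets)
digest = hashlib.md5(hashes.encode()).hexdigest()
# The directory name combines the split key with the digest of all sub-dataset hashes.
print(os.path.join("<retro_workdir>", "query", f"train_{digest}"))
```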
@@ -8,6 +8,7 @@ from megatron import get_retro_args
 from megatron.tokenizer.tokenizer import (
     _BertWordPieceTokenizer,
     _GPT2BPETokenizer,
+    _GPTSentencePieceTokenizer,
 )
 
@@ -28,10 +29,18 @@ def get_num_chunks_per_sample():
 def get_gpt_tokenizer():
     '''GPT (BPE) tokenizer.'''
     args = get_retro_args()
-    return _GPT2BPETokenizer(
-        vocab_file=args.retro_gpt_vocab_file,
-        merge_file=args.retro_gpt_merge_file,
-    )
+    tokenizer_type = args.retro_gpt_tokenizer_type
+    if tokenizer_type == "GPT2BPETokenizer":
+        assert args.retro_gpt_vocab_file and args.retro_gpt_merge_file
+        return _GPT2BPETokenizer(
+            vocab_file=args.retro_gpt_vocab_file,
+            merge_file=args.retro_gpt_merge_file,
+        )
+    elif tokenizer_type == 'GPTSentencePieceTokenizer':
+        assert args.retro_gpt_tokenizer_model is not None
+        return _GPTSentencePieceTokenizer(args.retro_gpt_tokenizer_model)
+    else:
+        raise Exception("unrecognized gpt tokenizer, '%s'." % tokenizer_type)
 
 
 def get_bert_tokenizer():
......
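`get_gpt_tokenizer()` now dispatches on `retro_gpt_tokenizer_type` instead of assuming BPE. A sketch of the two argument shapes the branches expect (the file names below are placeholders, not defaults):

```python
from types import SimpleNamespace

# BPE path: requires vocab and merge files.
bpe_args = SimpleNamespace(
    retro_gpt_tokenizer_type="GPT2BPETokenizer",
    retro_gpt_vocab_file="gpt2-vocab.json",
    retro_gpt_merge_file="gpt2-merges.txt",
    retro_gpt_tokenizer_model=None,
)

# SentencePiece path: requires only the tokenizer model file.
spm_args = SimpleNamespace(
    retro_gpt_tokenizer_type="GPTSentencePieceTokenizer",
    retro_gpt_vocab_file=None,
    retro_gpt_merge_file=None,
    retro_gpt_tokenizer_model="tokenizer.model",
)
```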
@@ -13,6 +13,7 @@ from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
+from megatron.arguments import core_transformer_config_from_args
 from megatron.text_generation_server import MegatronServer
 from megatron.text_generation import generate_and_post_process
 from megatron.text_generation import beam_search_and_post_process
@@ -21,8 +22,10 @@ import torch
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
+    config = core_transformer_config_from_args(get_args())
+
     print_rank_0('building GPT model ...')
-    model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
+    model = GPTModel(config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
 
     return model
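The functional change in `model_provider` is that `GPTModel` now receives a transformer config built from the parsed arguments as its first positional parameter. Roughly, `core_transformer_config_from_args` copies the model-shape arguments into a config object; a simplified, illustrative sketch of that pattern (the fields shown are a small made-up subset, not the real config):

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class SimpleTransformerConfig:
    # Illustrative subset of the fields a transformer config might carry.
    num_layers: int
    hidden_size: int
    num_attention_heads: int


def config_from_args(args):
    # Copy matching attributes from the argument namespace into the config.
    return SimpleTransformerConfig(
        num_layers=args.num_layers,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
    )


cfg = config_from_args(SimpleNamespace(num_layers=24, hidden_size=1024,
                                       num_attention_heads=16))
```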
@@ -37,6 +40,8 @@ def add_text_generate_args(parser):
                        help='Top k sampling.')
     group.add_argument("--out-seq-length", type=int, default=1024,
                        help='Size of the output generated text.')
+    group.add_argument("--port", type=int, default=5000,
+                       help='port for text generation server to run on')
 
     return parser
@@ -50,6 +55,9 @@ if __name__ == "__main__":
     if args.num_layers_per_virtual_pipeline_stage is not None:
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()
+    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text "
+                 "generation.")
+    args.exit_on_missing_checkpoint = True
 
     # Set up model and load checkpoint
     model = get_model(model_provider, wrap_with_ddp=False)
@@ -60,7 +68,7 @@ if __name__ == "__main__":
     model = model[0]
     if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
         server = MegatronServer(model)
-        server.run("0.0.0.0")
+        server.run("0.0.0.0",port=args.port)
 
     while True:
         choice = torch.cuda.LongTensor(1)
......
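With the new `--port` argument the text generation server can be bound to a port other than the default 5000. A hypothetical client call against a locally running server; the `/api` endpoint and request fields should be checked against `megatron/text_generation_server.py` in your checkout before relying on them:

```python
import json

import requests

# Assumes the server was launched with '--port 5000' on this host.
resp = requests.put(
    "http://localhost:5000/api",
    data=json.dumps({"prompts": ["Megatron-LM is"], "tokens_to_generate": 32}),
    headers={"Content-Type": "application/json"},
)
print(resp.json())
```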