"example/vscode:/vscode.git/clone" did not exist on "b5b85620b85860d0e93cecb859f0785966270cbf"
Commit 2eea6216 authored by rprenger's avatar rprenger
Browse files

Merging with main and fixing merge conflict

parents ed6806ac 5f694372
import os
import sys
import json
import argparse
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron.data import indexed_dataset
def main(args):
prefixes = set()
for basename in os.listdir(args.input):
prefix, ext = os.path.splitext(basename)
if prefix in prefixes:
continue
if not os.path.isfile(os.path.join(args.input, basename)):
continue
ext_pair = '.bin' if ext == '.idx' else '.idx'
assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'
prefixes.add(prefix)
builder = None
for prefix in sorted(prefixes):
if builder is None:
dataset = indexed_dataset.make_dataset(os.path.join(args.input, prefix), 'infer')
if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
builder = indexed_dataset.MMapIndexedDatasetBuilder(args.output_prefix + '.bin', dtype=dataset._index.dtype)
else:
builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')
del dataset
builder.merge_file_(os.path.join(args.input, prefix))
builder.finalize(args.output_prefix + '.idx')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to directory containing all document files to merge')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
args = parser.parse_args()
assert os.path.isdir(args.input), \
f'ERROR: {args.input} is not a directory or does not exist'
assert os.path.isdir(os.path.dirname(args.output_prefix)), \
f'ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist'
main(args)
...@@ -122,8 +122,10 @@ def get_args(): ...@@ -122,8 +122,10 @@ def get_args():
choices=['lazy', 'cached', 'mmap']) choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime') group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1, group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch') help='Number of worker processes to launch')
group.add_argument('--chunk-size', type=int, required=True,
help='Chunk size assigned to each worker process')
group.add_argument('--log-interval', type=int, default=100, group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates') help='Interval between progress updates')
args = parser.parse_args() args = parser.parse_args()
...@@ -154,7 +156,7 @@ def main(): ...@@ -154,7 +156,7 @@ def main():
encoder = Encoder(args) encoder = Encoder(args)
tokenizer = build_tokenizer(args) tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25) encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
#encoded_docs = map(encoder.encode, fin) #encoded_docs = map(encoder.encode, fin)
level = "document" level = "document"
......
...@@ -28,6 +28,7 @@ from megatron.model import GPTModel ...@@ -28,6 +28,7 @@ from megatron.model import GPTModel
from megatron.training import get_model from megatron.training import get_model
from megatron.text_generation_server import MegatronServer from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
import torch import torch
def model_provider(pre_process=True, post_process=True): def model_provider(pre_process=True, post_process=True):
...@@ -82,3 +83,8 @@ if __name__ == "__main__": ...@@ -82,3 +83,8 @@ if __name__ == "__main__":
generate_and_post_process(model) generate_and_post_process(model)
except ValueError as ve: except ValueError as ve:
pass pass
elif choice[0].item() == 1:
try:
beam_search_and_post_process(model)
except ValueError as ve:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment