Commit 2eea6216 authored by rprenger

Merging with main and fixing merge conflict

parents ed6806ac 5f694372
@@ -122,8 +122,10 @@ def get_args():
                        choices=['lazy', 'cached', 'mmap'])

     group = parser.add_argument_group(title='runtime')
-    group.add_argument('--workers', type=int, default=1,
+    group.add_argument('--workers', type=int, required=True,
                        help='Number of worker processes to launch')
+    group.add_argument('--chunk-size', type=int, required=True,
+                       help='Chunk size assigned to each worker process')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
     args = parser.parse_args()
@@ -154,7 +156,7 @@ def main():
     encoder = Encoder(args)
     tokenizer = build_tokenizer(args)
     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
+    encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
     #encoded_docs = map(encoder.encode, fin)

     level = "document"
...
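For context on the second hunk: the third positional argument to multiprocessing.Pool.imap is chunksize, the number of input items handed to a worker per dispatch, so the change replaces the hard-coded chunk size of 25 with the value of the new --chunk-size flag. Below is a minimal, runnable sketch of the same pattern; the encode function and the chunk size are illustrative stand-ins, not code from this repository.

import multiprocessing

def encode(line):
    # Stand-in for Encoder.encode: tokenize one line of input text.
    return line.strip().split()

if __name__ == '__main__':
    lines = ['one line of raw text\n'] * 1000
    with multiprocessing.Pool(processes=4) as pool:
        # chunksize batches items per worker dispatch: larger chunks cut
        # inter-process overhead, smaller chunks balance load more evenly.
        for tokens in pool.imap(encode, lines, chunksize=32):
            pass

Larger chunk sizes amortize pickling and queue overhead across more documents, which is presumably why the commit exposes the value as a tunable flag instead of keeping the constant.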
@@ -28,6 +28,7 @@ from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_server import MegatronServer
 from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import beam_search_and_post_process
 import torch

 def model_provider(pre_process=True, post_process=True):
@@ -82,3 +83,8 @@ if __name__ == "__main__":
                 generate_and_post_process(model)
             except ValueError as ve:
                 pass
+        elif choice[0].item() == 1:
+            try:
+                beam_search_and_post_process(model)
+            except ValueError as ve:
+                pass
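The new branch extends the server's dispatch on choice, a tensor the existing code inspects via choice[0].item(); a value of 1 now routes the request to beam search through the newly imported beam_search_and_post_process. The following is a hedged sketch of the dispatch shape only, with stubs in place of the Megatron calls: the real server builds model via get_model and broadcasts choice across ranks, none of which appears in this diff.

import torch

def generate_and_post_process(model):
    # Stub for megatron.text_generation.generate_and_post_process.
    print('sampling path')

def beam_search_and_post_process(model):
    # Stub for the beam-search entry point imported by this commit.
    print('beam search path')

model = None  # placeholder; the real server obtains this from get_model()
choice = torch.tensor([1])  # illustrative request type

if choice[0].item() == 0:
    try:
        generate_and_post_process(model)
    except ValueError:
        pass
elif choice[0].item() == 1:
    try:
        beam_search_and_post_process(model)
    except ValueError:
        pass

Swallowing ValueError in both branches mirrors the diff: a malformed request is dropped rather than crashing the long-running server.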