Commit ee7b19e7 authored by Mostofa Patwary

Merge branch 'main' into main_dedup

parents d413bd5f f2d64c00
@@ -35,8 +35,8 @@ def model_provider():
     """Build the model."""
     print_rank_0('building GPT model ...')
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+
+    def model_provider_pipelined():
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = GPTModelFirstStage(num_tokentypes=0)
@@ -46,6 +46,17 @@ def model_provider():
         else:
             model = GPTModelIntermediateStage(
                 num_tokentypes=0)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
     else:
         model = GPTModel(num_tokentypes=0, parallel_output=True)
...
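The hunk above restructures model_provider for interleaved (virtual) pipeline parallelism: the stage-specific construction moves into a nested model_provider_pipelined(), and when --virtual-pipeline-model-parallel-size is set, each pipeline rank builds a list of model chunks, one per virtual stage, instead of a single module. A minimal sketch of that construction pattern, with a hypothetical build_stage() standing in for the GPTModel* constructors and the mpu call passed in as a plain callable:

# Sketch only: build_stage and set_virtual_rank are hypothetical stand-ins
# for the GPTModel* constructors and
# mpu.set_virtual_pipeline_model_parallel_rank.
def build_model_chunks(virtual_size, set_virtual_rank, build_stage):
    if virtual_size is None:
        # Plain pipeline parallelism: one contiguous chunk per rank.
        return build_stage()
    chunks = []
    for virtual_rank in range(virtual_size):
        # Set the virtual rank before each build so that
        # is_pipeline_first_stage()/is_pipeline_last_stage() answer
        # for the chunk currently being constructed.
        set_virtual_rank(virtual_rank)
        chunks.append(build_stage())
    return chunks

# Example: a rank with 2 virtual stages gets a list of 2 chunks.
chunks = build_model_chunks(2, lambda r: None, lambda: object())
assert isinstance(chunks, list) and len(chunks) == 2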
@@ -47,6 +47,20 @@ def get_tasks_args(parser):
                            help='Sliding window for overlapping evaluation.')
     group.add_argument('--strict-lambada', action='store_true',
                        help='Use more difficult formulation of lambada.')
+
+    # Retriever args
+    group.add_argument('--qa-data-dev', type=str, default=None,
+                       help='Path to the QA dataset dev file.')
+    group.add_argument('--qa-data-test', type=str, default=None,
+                       help='Path to the QA dataset test file.')
+
+    # Faiss arguments for retriever
+    group.add_argument('--faiss-use-gpu', action='store_true',
+                       help='Whether to create the FaissMIPSIndex on GPU.')
+    group.add_argument('--faiss-match', type=str, default='string',
+                       choices=['regex', 'string'],
+                       help='Answer matching logic type.')
+    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
+                       help='Number of blocks to use as top-k during retrieval.')
     return parser
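These flags wire retriever evaluation into the shared task argument parser. A self-contained sketch of how they behave under plain argparse (outside Megatron's get_args() machinery), with a hypothetical contains_answer() helper illustrating the string-vs-regex matching that --faiss-match selects:

import argparse
import re

parser = argparse.ArgumentParser()
group = parser.add_argument_group('retriever')
group.add_argument('--faiss-match', type=str, default='string',
                   choices=['regex', 'string'],
                   help='Answer matching logic type.')
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                   help='Number of blocks to use as top-k during retrieval.')
args = parser.parse_args(['--faiss-match', 'regex'])

# Hypothetical answer-matching helper selected by --faiss-match.
def contains_answer(block_text, answer, match):
    if match == 'regex':
        return re.search(answer, block_text) is not None
    return answer in block_text

assert contains_answer('the capital of France is Paris',
                       r'Pari\w', args.faiss_match)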
@@ -62,6 +76,8 @@ if __name__ == '__main__':
         from glue.finetune import main
     elif args.task in ['LAMBADA', 'WIKITEXT103']:
         from zeroshot_gpt.evaluate import main
+    elif args.task in ['ICT-ZEROSHOT-NQ']:
+        from orqa.evaluate_orqa import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
...
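The new branch routes the ICT-ZEROSHOT-NQ task to the ORQA evaluation entry point, following the file's lazy-import dispatch. The same pattern can be sketched as a lookup table (module paths as in the diff; the helper itself is hypothetical):

import importlib

# Task name -> module whose main() runs the evaluation (paths from the diff).
_TASK_MODULES = {
    'LAMBADA': 'zeroshot_gpt.evaluate',
    'WIKITEXT103': 'zeroshot_gpt.evaluate',
    'ICT-ZEROSHOT-NQ': 'orqa.evaluate_orqa',
}

def get_task_main(task):
    if task not in _TASK_MODULES:
        raise NotImplementedError('Task {} is not implemented.'.format(task))
    # Import lazily so unrelated tasks do not pull in heavy dependencies.
    return importlib.import_module(_TASK_MODULES[task]).main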