Commit 1aa2e08a authored by Jared Casper

Merge branch 'fix_zeroshot_bug' into 'main'

GPT2->GPT in zero-shot evaluation scripts

See merge request ADLR/megatron-lm!210
parents f34cc86b 5ff0f882
@@ -61,7 +61,7 @@ if __name__ == '__main__':
     elif args.task in ['MNLI', 'QQP']:
         from glue.finetune import main
     elif args.task in ['LAMBADA', 'WIKITEXT103']:
-        from zeroshot_gpt2.evaluate import main
+        from zeroshot_gpt.evaluate import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""GPT2 zero-shot evaluation."""
+"""GPT zero-shot evaluation."""
 
 import math
@@ -24,7 +24,7 @@ from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
+from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
 from megatron.training import get_model, communicate
 from megatron.utils import get_ltor_masks_and_position_ids
 from tasks.finetune_utils import build_data_loader
@@ -47,18 +47,18 @@ def get_model_provider(eval_metric):
             raise NotImplementedError('output type for {} evaluation metric '
                                       'is not supported.'.format(eval_metric))
 
-        print_rank_0('building GPT2 model ...')
+        print_rank_0('building GPT model ...')
         if mpu.get_pipeline_model_parallel_world_size() > 1:
             # Determine model based on position of stage in pipeline.
             if mpu.is_pipeline_first_stage():
-                model = GPT2ModelFirstStage(num_tokentypes=0)
+                model = GPTModelFirstStage(num_tokentypes=0)
             elif mpu.is_pipeline_last_stage():
-                model = GPT2ModelLastStage(
+                model = GPTModelLastStage(
                     parallel_output=parallel_output, num_tokentypes=0)
             else:
-                model = GPT2ModelIntermediateStage(num_tokentypes=0)
+                model = GPTModelIntermediateStage(num_tokentypes=0)
         else:
-            model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)
+            model = GPTModel(num_tokentypes=0, parallel_output=parallel_output)
         return model
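For context, the provider patched above is normally handed to get_model, after which the pretrained checkpoint is loaded before zero-shot evaluation begins. The short sketch below is not part of this merge request: it shows one plausible call pattern inferred from the imports visible in the diff, and the helper name build_eval_model plus the exact load_checkpoint arguments are assumptions for illustration only.

# Hedged sketch, not part of the change: how the renamed model provider
# might be consumed in the zero-shot task code.
from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.training import get_model


def build_eval_model(eval_metric):
    """Build the GPT model for zero-shot evaluation and load its weights.

    Hypothetical helper; names other than the imports above are assumed.
    """
    args = get_args()

    # get_model wraps the provider and handles parallel placement, so the
    # provider itself only has to pick the right GPTModel* stage class.
    model = get_model(get_model_provider(eval_metric))

    # Load pretrained weights; no optimizer or LR scheduler is needed for eval.
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    model.eval()
    return model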