Commit 0c151638 authored by Jared Casper, committed by Deepak Narayanan

Add implementation for pipelined zeroshot GPT-2 evaluation

parent 3afcba6e
@@ -20,12 +20,12 @@ import math
 import torch

 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
-from megatron.training import get_model
+from megatron.training import get_model, communicate
 from megatron.utils import get_ltor_masks_and_position_ids
 from tasks.finetune_utils import build_data_loader
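The key new import is communicate, the point-to-point primitive this commit uses to move activations between adjacent pipeline stages. Below is a minimal sketch of its call semantics only, assuming blocking torch.distributed send/recv, an explicitly passed tensor shape, and explicit neighbor ranks; the real Megatron helper derives all of these from the pipeline process group and batches non-blocking operations, so treat every name and argument here as illustrative.

import torch
import torch.distributed as dist

def communicate_sketch(tensor_send_next, tensor_send_prev,
                       recv_forward, recv_backward,
                       shape, prev_rank, next_rank,
                       dtype=torch.float32, device='cuda'):
    """Illustrative stand-in for megatron.training.communicate.

    Forward activations flow prev_rank -> this rank -> next_rank;
    backward gradients flow the other way (unused in this
    inference-only evaluation path).
    """
    tensor_recv_prev = None
    tensor_recv_next = None
    if recv_forward:
        # Receive activations produced by the previous stage.
        tensor_recv_prev = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(tensor_recv_prev, src=prev_rank)
    if tensor_send_next is not None:
        # Hand our activations to the next stage.
        dist.send(tensor_send_next.contiguous(), dst=next_rank)
    if recv_backward:
        tensor_recv_next = torch.empty(shape, dtype=dtype, device=device)
        dist.recv(tensor_recv_next, src=next_rank)
    if tensor_send_prev is not None:
        dist.send(tensor_send_prev.contiguous(), dst=prev_rank)
    return tensor_recv_prev, tensor_recv_next

In a forward-only pipeline this ordering cannot deadlock: each stage first receives from its predecessor, then sends to its successor.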
@@ -48,6 +48,16 @@ def get_model_provider(eval_metric):
                                       'is not supported.'.format(eval_metric))

         print_rank_0('building GPT2 model ...')
-        model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)
+        if mpu.get_pipeline_model_parallel_world_size() > 1:
+            # Determine model based on position of stage in pipeline.
+            if mpu.is_pipeline_first_stage():
+                model = GPT2ModelFirstStage(num_tokentypes=0)
+            elif mpu.is_pipeline_last_stage():
+                model = GPT2ModelLastStage(
+                    parallel_output=parallel_output, num_tokentypes=0)
+            else:
+                model = GPT2ModelIntermediateStage(num_tokentypes=0)
+        else:
+            model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)

         return model
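With pipeline parallelism the provider returns a different module per stage: only the first stage owns the embedding (so it consumes token ids), only the last stage owns the output head (so it produces logits, with parallel_output controlling whether they stay split across tensor-parallel ranks), and intermediate stages map hidden states to hidden states. A hedged sketch of the three interfaces, with single layers standing in for the real transformer stacks; the class names and bodies below are illustrative, not the megatron.model implementations.

import torch.nn as nn

HIDDEN, VOCAB = 8, 16  # toy sizes for the sketch

class FirstStageSketch(nn.Module):
    """Consumes raw token ids; the embedding lives only here."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)

    def forward(self, input_ids, position_ids, attention_mask):
        return self.embed(input_ids)

class IntermediateStageSketch(nn.Module):
    """Hidden states in, hidden states out."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(HIDDEN, HIDDEN)

    def forward(self, hidden_states, attention_mask):
        return self.layer(hidden_states)

class LastStageSketch(nn.Module):
    """Produces logits; the output head lives only here."""
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(HIDDEN, VOCAB)

    def forward(self, hidden_states, attention_mask):
        return self.head(hidden_states)

This signature split explains the two call forms in forward_step below: the first stage is called with (tokens, position_ids, attention_mask), every later stage with (input_tensor, attention_mask).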
@@ -83,9 +93,39 @@ def forward_step(batch, model, eval_metric):
     tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
         batch)

+    # Tell the model what our actual batch size will be
+    args = get_args()
+    args.micro_batch_size = len(labels)
+
     # Forward model.
-    output = model(tokens, position_ids, attention_mask)
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward pass through the model.
+    if mpu.is_pipeline_first_stage():
+        assert input_tensor is None
+        if mpu.is_pipeline_last_stage():
+            output = model(tokens, position_ids, attention_mask)
+        else:
+            output = model(tokens, position_ids, attention_mask)
+    else:
+        assert input_tensor is not None
+        output = model(input_tensor, attention_mask)
+
+    if not mpu.is_pipeline_last_stage():
+        communicate(tensor_send_next=output,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)
+        return None
+
-    # For loss, return the unreduced loss.
-    if eval_metric == 'loss':
-        losses = mpu.vocab_parallel_cross_entropy(
+    if mpu.is_pipeline_last_stage():
+        # For loss, return the unreduced loss.
+        if eval_metric == 'loss':
+            losses = mpu.vocab_parallel_cross_entropy(
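Condensed, the new forward_step is a receive -> forward -> send step: every stage but the first receives its input activations, every stage but the last ships its output onward and reports nothing, and only the last stage turns logits into the requested metric. A sketch of that control flow, reusing the illustrative communicate_sketch helper above; the flag arguments replace the mpu queries, the first-stage branches are folded into one call, and the metric computation is elided, all as simplifying assumptions.

def forward_step_sketch(tokens, position_ids, attention_mask, model,
                        first_stage, last_stage,
                        shape, prev_rank, next_rank):
    # Every stage except the first receives activations from upstream.
    if first_stage:
        output = model(tokens, position_ids, attention_mask)
    else:
        input_tensor, _ = communicate_sketch(
            tensor_send_next=None, tensor_send_prev=None,
            recv_forward=True, recv_backward=False,
            shape=shape, prev_rank=prev_rank, next_rank=next_rank)
        output = model(input_tensor, attention_mask)

    # Every stage except the last forwards its activations and returns
    # None; only the last stage holds a result worth reducing.
    if not last_stage:
        communicate_sketch(
            tensor_send_next=output, tensor_send_prev=None,
            recv_forward=False, recv_backward=False,
            shape=shape, prev_rank=prev_rank, next_rank=next_rank)
        return None
    return output  # logits; the real code turns these into loss/accuracy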
@@ -104,6 +144,7 @@ def forward_step(batch, model, eval_metric):
         raise NotImplementedError('forward method for evaluation metric {} '
                                   'is not implemented.'.format(eval_metric))

+    return None


 def evaluate(data_loader, model, eval_metric):
@@ -123,6 +164,7 @@ def evaluate(data_loader, model, eval_metric):
             output = forward_step(batch, model, eval_metric)

             # Reduce across processes.
-            torch.distributed.all_reduce(output,
-                                         group=mpu.get_data_parallel_group())
+            if mpu.is_pipeline_last_stage():
+                torch.distributed.all_reduce(output,
+                                             group=mpu.get_data_parallel_group())
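Because forward_step now returns None everywhere except the last pipeline stage, the data-parallel reduction has to be guarded the same way: each pipeline stage has its own data-parallel group, and only the last stage's group holds real metric tensors. A small sketch of the guarded accumulation, where dp_group and the on_last_stage flag stand in for mpu.get_data_parallel_group() and mpu.is_pipeline_last_stage(), and CUDA tensors are assumed as in Megatron.

import torch
import torch.distributed as dist

def accumulate_metric_sketch(batch_outputs, on_last_stage, dp_group):
    """Sum per-batch metrics across data-parallel replicas.

    batch_outputs yields one tensor per batch on the last stage and
    None on every other stage.
    """
    total = torch.zeros(1, device='cuda')
    for output in batch_outputs:
        if on_last_stage:
            dist.all_reduce(output, group=dp_group)
            total += output
    return total if on_last_stage else None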
@@ -138,6 +180,7 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric):
     output = evaluate(data_loader, model, eval_metric)

     string = ' validation results on {} | '.format(task)
-    if eval_metric == 'loss':
-        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
-        num_original_tokens = data_loader.dataset.num_original_tokens
+    if is_last_rank():
+        if eval_metric == 'loss':
+            num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
+            num_original_tokens = data_loader.dataset.num_original_tokens
@@ -162,9 +205,9 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric):
                 'implemented yet.'.format(eval_metric))

-    length = len(string) + 1
-    print_rank_0('-' * length)
-    print_rank_0(string)
-    print_rank_0('-' * length)
+        length = len(string) + 1
+        print('-' * length)
+        print(string)
+        print('-' * length)
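The print changes are the visible consequence of pipelining: the finished metric lives on the last pipeline stage, while global rank 0, which print_rank_0 gates on, sits on the first stage and never sees it. Hence the guard moves to is_last_rank() with plain print calls inside. The two helpers likely look roughly like this, following the usual Megatron pattern but not verified against this revision.

import torch

def print_rank_0(message):
    """Print only on global rank 0 (a first-stage rank)."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)

def is_last_rank():
    """True on the highest global rank, which lies on the last stage
    and therefore actually holds the evaluation results."""
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)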
 def main():
...