Commit b7071993 authored by Jared Casper

Merge branch 'skyw/fix_zeroshot_eval_script' into 'main'

remove mpu dependency in zeroshot script

See merge request ADLR/megatron-lm!493
parents e1c334b0 8ed3887a
@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron.core import mpu
+from megatron.core import parallel_state, tensor_parallel
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model
@@ -90,10 +90,10 @@ def forward_step(batch, model, eval_metric):
     send_forward(output)

-    if mpu.is_pipeline_last_stage():
+    if parallel_state.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
+            losses = tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())
@@ -129,9 +129,9 @@ def evaluate(data_loader, model, eval_metric):
             output = forward_step(batch, model, eval_metric)

             # Reduce across processes.
-            if mpu.is_pipeline_last_stage():
+            if parallel_state.is_pipeline_last_stage():
                 torch.distributed.all_reduce(output,
-                    group=mpu.get_data_parallel_group())
+                    group=parallel_state.get_data_parallel_group())

                 total_output += output
...
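
For readers updating similar scripts: the change swaps the legacy megatron.core.mpu entry point for the megatron.core.parallel_state and megatron.core.tensor_parallel modules. Below is a minimal sketch of the resulting call pattern, not the merged code itself; it assumes a Megatron-LM version where megatron.core exposes both modules, that torch.distributed and Megatron model parallelism are already initialized, and that the hypothetical logits/labels tensors have the usual shapes.

    import torch

    from megatron.core import parallel_state, tensor_parallel


    def unreduced_loss(logits: torch.Tensor, labels: torch.Tensor):
        """Per-token loss on the last pipeline stage, summed across data ranks."""
        # Topology queries that used to go through mpu now live in parallel_state.
        if not parallel_state.is_pipeline_last_stage():
            return None
        # The vocab-parallel cross entropy moved from mpu.tensor_parallel
        # to the top-level tensor_parallel module.
        losses = tensor_parallel.vocab_parallel_cross_entropy(
            logits.contiguous().float(), labels.contiguous())
        loss = losses.view(-1).sum()
        # Reductions use the process groups exposed by parallel_state.
        torch.distributed.all_reduce(
            loss, group=parallel_state.get_data_parallel_group())
        return loss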