Commit 64a83fb5 authored by Jared Casper

Fixing text generation and zeroshot eval and addressing comments.

parent b938ec51
@@ -15,7 +15,7 @@
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-from .distributed import *
+from .distributed import DistributedDataParallel
 from .bert_model import BertModel
 from .gpt_model import GPTModel
 from .language_model import get_language_model
...
@@ -22,7 +22,9 @@ from megatron import get_num_microbatches
 from megatron import get_timers
 from megatron import mpu
 from megatron import p2p_communication
+from megatron.utils import unwrap_model
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module

 def get_forward_backward_func():
     args = get_args()
@@ -46,8 +48,9 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     timers = get_timers()

     timers('forward-compute').start()
-    # TODO
-    model.module.module.set_input_tensor(input_tensor)
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
...
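For context, a minimal sketch of the unwrap pattern used above: the model handed to the schedule may be nested inside wrappers (torch DDP, the local DDP, the fp16 wrapper), each of which stores the wrapped module in `.module`, and only the innermost model exposes `set_input_tensor`. This is an illustration of that pattern, not necessarily the exact implementation of `megatron.utils.unwrap_model`.

# Illustrative sketch only; the real megatron.utils.unwrap_model may differ.
def unwrap_model(model, module_instances):
    # Accept either a single module or a list of model chunks.
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped = []
    for m in model:
        # Peel off wrapper layers (e.g. torchDDP, LocalDDP, Float16Module),
        # each of which keeps the wrapped module in `.module`.
        while isinstance(m, module_instances):
            m = m.module
        unwrapped.append(m)
    return unwrapped if return_list else unwrapped[0]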
@@ -26,9 +26,14 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
 from megatron.p2p_communication import recv_forward, send_forward
+
+# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module

 def get_batch(context_tokens):
     """Generate batch from context tokens."""
     args = get_args()
@@ -403,7 +408,9 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     input_tensor = recv_forward()

     # Forward pass through the model.
-    model.set_input_tensor(input_tensor)
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
     output_tensor = model(tokens, position_ids, attention_mask,
                           tokentype_ids=tokentype_ids,
                           layer_past=layer_past,
...
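The reason for the `set_input_tensor` call: with pipeline parallelism, only the first stage feeds token embeddings into its transformer layers; every later stage receives the previous stage's activations via `recv_forward` and must consume them instead of the embedding output. Below is a hypothetical sketch of how a pipeline-aware stage can use the stored tensor; GPTModel's actual handling of `pre_process` and `set_input_tensor` may differ.

# Hypothetical illustration, not Megatron's GPTModel.
import torch

class PipelineStageSketch(torch.nn.Module):
    def __init__(self, pre_process, layers, embedding=None):
        super().__init__()
        self.pre_process = pre_process          # True only on the first stage
        self.embedding = embedding              # present only when pre_process
        self.layers = torch.nn.ModuleList(layers)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Activation received from the previous stage (None on the first stage).
        self.input_tensor = input_tensor

    def forward(self, tokens, position_ids, attention_mask):
        if self.pre_process:
            hidden = self.embedding(tokens, position_ids)
        else:
            hidden = self.input_tensor
        for layer in self.layers:
            hidden = layer(hidden, attention_mask)
        return hidden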
@@ -231,6 +231,9 @@ def finetune(train_valid_datasets_provider, model_provider,
     args = get_args()
     timers = get_timers()

+    assert args.rampup_batch_size is None, \
+        'batch size scaling is not supported for finetuning'
+
     # Train and validation data loaders.
     timers('train/valid/test dataset/dataloder').start()
     if args.epochs > 0:
...
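The new assert rejects batch-size ramp-up during finetuning. For intuition, a ramp-up schedule grows the global batch size from a small starting value toward the target as training consumes samples; the sketch below shows one common linear form of such a schedule. It is illustrative only and is not Megatron's implementation of the `--rampup-batch-size` option.

# Illustrative only: one way a linear batch-size ramp-up can be computed.
def rampup_global_batch_size(samples_consumed, start, increment, ramp_samples, final):
    """Grow the batch size from `start` toward `final` in steps of `increment`."""
    if samples_consumed >= ramp_samples:
        return final
    steps = (final - start) // increment
    samples_per_step = ramp_samples // steps
    return start + increment * (samples_consumed // samples_per_step)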
@@ -24,20 +24,24 @@ from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
+from megatron.model import GPTModel
 from megatron.training import get_model
-from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
 from megatron.p2p_communication import recv_forward, send_forward
 from tasks.finetune_utils import build_data_loader
 from .datasets import build_dataset

+# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module

 def get_model_provider(eval_metric):
     """Based on evaluation metric set the parallel-output flag and
     return the model provider."""

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""

         if eval_metric == 'loss':
@@ -49,17 +53,8 @@ def get_model_provider(eval_metric):
                 'is not supported.'.format(eval_metric))

         print_rank_0('building GPT model ...')
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = GPTModelFirstStage(num_tokentypes=0)
-            elif mpu.is_pipeline_last_stage():
-                model = GPTModelLastStage(
-                    parallel_output=parallel_output, num_tokentypes=0)
-            else:
-                model = GPTModelIntermediateStage(num_tokentypes=0)
-        else:
-            model = GPTModel(num_tokentypes=0, parallel_output=parallel_output)
+        model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
+                         pre_process=pre_process, post_process=post_process)

         return model
@@ -98,19 +93,13 @@ def forward_step(batch, model, eval_metric):
     args = get_args()

     args.micro_batch_size = len(labels)

-    # Forward model.
     input_tensor = recv_forward()

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output = model(tokens, position_ids, attention_mask)
-        else:
-            output = model(tokens, position_ids, attention_mask)
-    else:
-        assert input_tensor is not None
-        output = model(input_tensor, attention_mask)
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
+    output = model(tokens, position_ids, attention_mask)

     send_forward(output)
...
@@ -26,33 +26,19 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import (GPTModel,
-                            GPTModelFirstStage,
-                            GPTModelLastStage,
-                            GPTModelIntermediateStage)
+from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
 from megatron.text_generation_utils import generate_samples_interactive

-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building GPT model ...')
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = GPTModelFirstStage(num_tokentypes=0)
-        elif mpu.is_pipeline_last_stage():
-            model = GPTModelLastStage(
-                num_tokentypes=0, parallel_output=False)
-        else:
-            model = GPTModelIntermediateStage(
-                num_tokentypes=0)
-    else:
-        model = GPTModel(num_tokentypes=0, parallel_output=False)
+    model = GPTModel(num_tokentypes=0, parallel_output=False,
+                     pre_process=pre_process, post_process=post_process)

     return model
...
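Both model providers now take pre_process/post_process flags instead of choosing among per-stage model classes, so a single GPTModel covers every pipeline position. A sketch of how a caller such as megatron.training.get_model might derive the flags from the stage's position is shown below; the actual call site may differ.

# Sketch of a caller deriving the flags from the pipeline position; illustrative only.
from megatron import mpu

def build_model(model_provider):
    # The first stage owns the embedding (pre_process); the last stage owns
    # the output head and loss computation (post_process).
    pre_process = mpu.is_pipeline_first_stage()
    post_process = mpu.is_pipeline_last_stage()
    return model_provider(pre_process=pre_process, post_process=post_process)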