Commit 9b174da8 authored by Deepak Narayanan

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
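
This merge separates the original model parallelism into tensor model parallelism (splitting individual layers across GPUs) and pipeline model parallelism (splitting the stack of layers into stages), and the diffs below repeatedly branch on a rank's position in the pipeline. As a rough sketch of the idea, not Megatron's actual mpu module (which records process groups built during initialization), the stage predicates can be derived from the global rank when consecutive blocks of ranks form one stage:

    import torch.distributed as dist

    # Sketch only: assumes ranks are laid out so consecutive blocks of ranks
    # form one pipeline stage; Megatron's mpu instead tracks groups created
    # at initialization time.
    def pipeline_stage_id(pipeline_world_size):
        ranks_per_stage = dist.get_world_size() // pipeline_world_size
        return dist.get_rank() // ranks_per_stage

    def is_pipeline_first_stage(pipeline_world_size):
        return pipeline_stage_id(pipeline_world_size) == 0

    def is_pipeline_last_stage(pipeline_world_size):
        return pipeline_stage_id(pipeline_world_size) == pipeline_world_size - 1
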
@@ -18,7 +18,8 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron.model.classification import Classification
+from megatron import mpu
+from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
@@ -44,8 +45,21 @@ def glue_classification(num_classes, Dataset,
         print_rank_0('building classification model for {} ...'.format(
             args.task))

-        return Classification(num_classes=num_classes, num_tokentypes=2)
+        if mpu.get_pipeline_model_parallel_world_size() > 1:
+            # Determine model based on position of stage in pipeline.
+            if mpu.is_pipeline_first_stage():
+                model = ClassificationFirstStage(
+                    num_classes=num_classes, num_tokentypes=2)
+            elif mpu.is_pipeline_last_stage():
+                model = ClassificationLastStage(
+                    num_classes=num_classes, num_tokentypes=2)
+            else:
+                model = ClassificationIntermediateStage(
+                    num_classes=num_classes, num_tokentypes=2)
+        else:
+            model = Classification(num_classes=num_classes, num_tokentypes=2)
+        return model

     def metrics_func_provider():
         """Privde metrics callback function."""
...
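
The ClassificationFirstStage, ClassificationIntermediateStage, and ClassificationLastStage classes give each pipeline stage only the pieces it needs: roughly, embeddings on the first stage, a slice of transformer layers on every stage, and the classification head on the last stage. A self-contained sketch of that split with illustrative torch.nn modules, not the actual Megatron classes:

    import torch.nn as nn

    # Illustrative only: the *FirstStage/*LastStage split places embeddings
    # on the first stage and the classification head on the last stage.
    class SketchFirstStage(nn.Module):
        def __init__(self, vocab_size, hidden_size, num_layers, num_heads):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, hidden_size)
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(hidden_size, num_heads,
                                           batch_first=True)
                for _ in range(num_layers))

        def forward(self, tokens):
            hidden = self.embedding(tokens)
            for layer in self.layers:
                hidden = layer(hidden)
            return hidden  # activation handed to the next stage

    class SketchLastStage(nn.Module):
        def __init__(self, hidden_size, num_layers, num_heads, num_classes):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(hidden_size, num_heads,
                                           batch_first=True)
                for _ in range(num_layers))
            self.head = nn.Linear(hidden_size, num_classes)

        def forward(self, hidden):  # activation from the previous stage
            for layer in self.layers:
                hidden = layer(hidden)
            return self.head(hidden[:, 0])  # classify on the first token
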
@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
         print_rank_0(' >> total number of samples: {}'.format(
             len(self.samples)))

+        self.sample_multiplier = NUM_CHOICES
+
     def __len__(self):
         return len(self.samples)
...
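
The new sample_multiplier field records that each RACE example expands into NUM_CHOICES token sequences (one per answer option), so batching code can size its tensors accordingly. Illustratively, with hypothetical shapes rather than the actual finetune code:

    import torch

    NUM_CHOICES = 4  # RACE provides four answer options per question

    # Each dataset sample yields NUM_CHOICES sequences; flattening multiplies
    # the effective batch size by sample_multiplier.
    micro_batch_size, seq_length = 8, 512
    tokens = torch.randint(0, 30000, (micro_batch_size, NUM_CHOICES, seq_length))
    flat = tokens.view(-1, seq_length)
    assert flat.shape == (micro_batch_size * NUM_CHOICES, seq_length)
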
@@ -18,7 +18,8 @@
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron.model.multiple_choice import MultipleChoice
+from megatron import mpu
+from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset
@@ -41,8 +42,18 @@ def model_provider():
     """Build the model."""

     print_rank_0('building multichoice model for RACE ...')
-    return MultipleChoice(num_tokentypes=2)
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = MultipleChoiceFirstStage(num_tokentypes=2)
+        elif mpu.is_pipeline_last_stage():
+            model = MultipleChoiceLastStage(num_tokentypes=2)
+        else:
+            model = MultipleChoiceIntermediateStage(num_tokentypes=2)
+    else:
+        model = MultipleChoice(num_tokentypes=2)
+    return model

 def metrics_func_provider():
...
@@ -20,12 +20,12 @@ import math

 import torch

 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPT2Model
-from megatron.training import get_model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
+from megatron.training import get_model, communicate
 from megatron.utils import get_ltor_masks_and_position_ids
 from tasks.finetune_utils import build_data_loader
@@ -48,7 +48,17 @@ def get_model_provider(eval_metric):
                                   'is not supported.'.format(eval_metric))

         print_rank_0('building GPT2 model ...')
-        model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)
+        if mpu.get_pipeline_model_parallel_world_size() > 1:
+            # Determine model based on position of stage in pipeline.
+            if mpu.is_pipeline_first_stage():
+                model = GPT2ModelFirstStage(num_tokentypes=0)
+            elif mpu.is_pipeline_last_stage():
+                model = GPT2ModelLastStage(
+                    parallel_output=parallel_output, num_tokentypes=0)
+            else:
+                model = GPT2ModelIntermediateStage(num_tokentypes=0)
+        else:
+            model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)

         return model
@@ -83,27 +93,58 @@ def forward_step(batch, model, eval_metric):
     tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
         batch)

-    # Forward model.
-    output = model(tokens, position_ids, attention_mask)
-
-    # For loss, return the unreduced loss.
-    if eval_metric == 'loss':
-        losses = mpu.vocab_parallel_cross_entropy(
-            output.contiguous().float(), labels.contiguous())
-        loss = torch.sum(
-            losses.view(-1) * loss_mask.contiguous().view(-1).float())
-        return loss
-
-    # For accuracy, return the number of correctly predicted samples.
-    if eval_metric == 'accuracy':
-        outputs = torch.argmax(output, -1)
-        correct = (outputs == labels).float()
-        correct[(1 - loss_mask).bool()] = 1
-        correct = correct.prod(-1)
-        return correct.sum()
-
-    raise NotImplementedError('forward method for evaluation metric {} '
-                              'is not implemented.'.format(eval_metric))
+    # Tell the model what our actual batch size will be
+    args = get_args()
+    args.micro_batch_size = len(labels)
+
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward pass through the model.
+    if mpu.is_pipeline_first_stage():
+        assert input_tensor is None
+        if mpu.is_pipeline_last_stage():
+            output = model(tokens, position_ids, attention_mask)
+        else:
+            output = model(tokens, position_ids, attention_mask)
+    else:
+        assert input_tensor is not None
+        output = model(input_tensor, attention_mask)
+
+    if not mpu.is_pipeline_last_stage():
+        communicate(tensor_send_next=output,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)
+        return None
+
+    if mpu.is_pipeline_last_stage():
+        # For loss, return the unreduced loss.
+        if eval_metric == 'loss':
+            losses = mpu.vocab_parallel_cross_entropy(
+                output.contiguous().float(), labels.contiguous())
+            loss = torch.sum(
+                losses.view(-1) * loss_mask.contiguous().view(-1).float())
+            return loss
+
+        # For accuracy, return the number of correctly predicted samples.
+        if eval_metric == 'accuracy':
+            outputs = torch.argmax(output, -1)
+            correct = (outputs == labels).float()
+            correct[(1 - loss_mask).bool()] = 1
+            correct = correct.prod(-1)
+            return correct.sum()
+
+        raise NotImplementedError('forward method for evaluation metric {} '
+                                  'is not implemented.'.format(eval_metric))
+
+    return None

 def evaluate(data_loader, model, eval_metric):
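
forward_step now threads activations through the pipeline by hand: every stage but the first receives its input tensor from the previous stage, every stage but the last sends its output onward and returns None, and only the last stage computes the metric. The real communicate() helper lives in megatron.training and also exchanges tensor shapes and supports the backward direction; below is a stripped-down sketch of the forward-only pattern with torch.distributed point-to-point calls, assuming a known shape and explicit peer ranks:

    import torch
    import torch.distributed as dist

    def communicate(tensor_send_next, tensor_send_prev, recv_forward,
                    recv_backward, shape=None, prev_rank=None, next_rank=None):
        # Sketch only: blocking send/recv, caller-supplied shape and peers.
        # Returns (tensor_recv_prev, tensor_recv_next) as the caller expects.
        tensor_recv_prev = None
        tensor_recv_next = None
        if recv_forward:
            tensor_recv_prev = torch.empty(shape)
            dist.recv(tensor_recv_prev, src=prev_rank)
        if tensor_send_next is not None:
            dist.send(tensor_send_next, dst=next_rank)
        if recv_backward:
            tensor_recv_next = torch.empty(shape)
            dist.recv(tensor_recv_next, src=next_rank)
        if tensor_send_prev is not None:
            dist.send(tensor_send_prev, dst=prev_rank)
        return tensor_recv_prev, tensor_recv_next
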
@@ -123,10 +164,11 @@ def evaluate(data_loader, model, eval_metric):
            output = forward_step(batch, model, eval_metric)

             # Reduce across processes.
-            torch.distributed.all_reduce(output,
-                                         group=mpu.get_data_parallel_group())
-
-            total_output += output
+            if mpu.is_pipeline_last_stage():
+                torch.distributed.all_reduce(output,
+                                             group=mpu.get_data_parallel_group())
+
+                total_output += output

     return total_output
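
Only the last pipeline stage produces a scalar from forward_step (earlier stages return None), so the data-parallel reduction is now guarded. The reduction itself is an ordinary subgroup all-reduce, sketched here with a hypothetical group of ranks:

    import torch
    import torch.distributed as dist

    # Sketch with a hypothetical data-parallel subgroup of ranks 0-3; every
    # member of the group must call all_reduce collectively.
    data_parallel_group = dist.new_group(ranks=[0, 1, 2, 3])
    output = torch.tensor(1.0)
    dist.all_reduce(output, op=dist.ReduceOp.SUM, group=data_parallel_group)
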
@@ -138,33 +180,34 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric):
     output = evaluate(data_loader, model, eval_metric)

     string = ' validation results on {} | '.format(task)
-    if eval_metric == 'loss':
-        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
-        num_original_tokens = data_loader.dataset.num_original_tokens
-        val_loss = output / (num_tokenized_tokens - 1)
-        ppl = math.exp(min(20, val_loss))
-        token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
-        adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
-        string += 'avg loss: {:.4E} | '.format(val_loss)
-        string += 'ppl: {:.4E} | '.format(ppl)
-        string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
-        string += 'token ratio: {} |'.format(token_ratio)
-
-    elif eval_metric == 'accuracy':
-        num_examples = len(data_loader.dataset)
-        acc = output / num_examples
-        string += 'number correct: {:.4E} | '.format(output)
-        string += 'total examples: {:.4E} | '.format(num_examples)
-        string += 'avg accuracy: {:.4E}'.format(acc)
-
-    else:
-        raise NotImplementedError('evaluation method for {} metric is not '
-                                  'implemented yet.'.format(eval_metric))
-
-    length = len(string) + 1
-    print_rank_0('-' * length)
-    print_rank_0(string)
-    print_rank_0('-' * length)
+    if is_last_rank():
+        if eval_metric == 'loss':
+            num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
+            num_original_tokens = data_loader.dataset.num_original_tokens
+            val_loss = output / (num_tokenized_tokens - 1)
+            ppl = math.exp(min(20, val_loss))
+            token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
+            adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
+            string += 'avg loss: {:.4E} | '.format(val_loss)
+            string += 'ppl: {:.4E} | '.format(ppl)
+            string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
+            string += 'token ratio: {} |'.format(token_ratio)
+
+        elif eval_metric == 'accuracy':
+            num_examples = len(data_loader.dataset)
+            acc = output / num_examples
+            string += 'number correct: {:.4E} | '.format(output)
+            string += 'total examples: {:.4E} | '.format(num_examples)
+            string += 'avg accuracy: {:.4E}'.format(acc)
+
+        else:
+            raise NotImplementedError('evaluation method for {} metric is not '
+                                      'implemented yet.'.format(eval_metric))
+
+        length = len(string) + 1
+        print('-' * length)
+        print(string)
+        print('-' * length)

 def main():
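
Printing switches from print_rank_0 to a plain print guarded by is_last_rank(): with pipeline parallelism, rank 0 sits on the first stage and never sees the final loss, while the last rank does. A sketch of what such a helper amounts to:

    import torch.distributed as dist

    def is_last_rank():
        # The highest global rank belongs to the last pipeline stage.
        return dist.get_rank() == dist.get_world_size() - 1
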
@@ -186,7 +229,7 @@ def main():

     # Data stuff.
     dataset = build_dataset(args.task)
-    dataloader = build_data_loader(dataset, args.batch_size,
+    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                    args.num_workers, drop_last=False)

     # Run evaluation.
...
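
args.batch_size becomes args.micro_batch_size throughout: with pipelining, the batch fed through one forward pass on one stage is a micro-batch, and several micro-batches are kept in flight to fill the pipeline. The usual accounting, with illustrative numbers only:

    micro_batch_size = 8      # sequences per forward pass on one pipeline stage
    num_micro_batches = 16    # micro-batches in flight per optimizer step
    data_parallel_size = 4    # independent model replicas
    global_batch_size = micro_batch_size * num_micro_batches * data_parallel_size
    assert global_batch_size == 512
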
@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
@@ -36,7 +37,19 @@ def model_provider():
     """Build the model."""

     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=False)
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(
+                num_tokentypes=0, parallel_output=False)
+        else:
+            model = GPT2ModelIntermediateStage(
+                num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=False)

     return model
@@ -86,7 +99,7 @@ def main():

     # Generate samples.
     if args.num_samples == 0:
-        args.batch_size = 1
+        args.micro_batch_size = 1
         if args.sample_input_file != None:
             generate_samples_input_from_file(model)
         else:
...
@@ -188,18 +188,18 @@ def main():
     # Args
     args = _parse_args(extra_args_provider=get_mp_merge_args)

     model_type = args.model_type
-    orig_model_parallel_size = args.model_parallel_size
-    args.model_parallel_size = 1
+    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
+    args.tensor_model_parallel_size = 1
     tokenizer = rebuild_tokenizer(args)

     print('\n merging model parallel partitions ...')
-    print(' > number of partitions: {}'.format(orig_model_parallel_size))
+    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
     print(' > checkpoint path: {}'.format(args.load))
     print(' > model parameters:')
     print('    number of tokens ................ {} '.format(
         tokenizer.vocab_size))
     print('    number of layers ................ {}'.format(args.num_layers))
-    print('    hidden sise ..................... {}'.format(args.hidden_size))
+    print('    hidden size ..................... {}'.format(args.hidden_size))
     print('    number of attention heads ....... {}'.format(
         args.num_attention_heads))
     print('    maximum position embeddings ..... {}'.format(
@@ -207,18 +207,18 @@ def main():

     # Full model.
     print('> building the full model ...')
-    mpu.initialize.set_model_parallel_world_size(1)
-    mpu.initialize.set_model_parallel_rank(0)
+    mpu.initialize.set_tensor_model_parallel_world_size(1)
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     merged_model = get_model(model_type)

     # Build and load partitions.
     partitions = []
     iteration = 0
-    args.model_parallel_size = orig_model_parallel_size
+    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
     tokenizer = rebuild_tokenizer(args)
-    mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
-    for rank in range(args.model_parallel_size):
-        mpu.initialize.set_model_parallel_rank(rank)
+    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+    for rank in range(args.tensor_model_parallel_size):
+        mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         print('> loading {} ...'.format(checkpoint_name))
         model_ = get_model(model_type)
@@ -248,7 +248,7 @@ def main():
                 rank, partition_param.dtype, list(partition_param.size())))

         # For the non-parallel parameters, simply copy the rank 0 values.
-        if not hasattr(merged_param, 'model_parallel'):
+        if not hasattr(merged_param, 'tensor_model_parallel'):
             print('     none-parallel parameter, simple copy from rank 0')
             with torch.no_grad():
                 merged_param.data.copy_(partitions_param[0].data)
@@ -267,8 +267,8 @@ def main():

     # Save the model.
-    args.model_parallel_size = 1
-    mpu.initialize.set_model_parallel_rank(0)
+    args.tensor_model_parallel_size = 1
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     sd = {}
     sd['model'] = merged_model.state_dict_for_save_checkpoint()
     sd['iteration'] = iteration
...
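
This tool rebuilds a single checkpoint from tensor-parallel shards. A rough sketch of the core merge operation, assuming the Megatron convention that column-parallel Linear weights are sharded along dim 0 and row-parallel weights along dim 1:

    import torch

    # Sketch only: concatenate per-rank shards back along the partition
    # dimension to recover the full parameter.
    def merge_partitions(partitions, partition_dim):
        return torch.cat(partitions, dim=partition_dim)

    # Example: a 4-way column-parallel weight of full shape (1024, 4096) is
    # stored as four (256, 4096) shards and concatenated back along dim 0.
    shards = [torch.randn(256, 4096) for _ in range(4)]
    full = merge_partitions(shards, partition_dim=0)
    assert full.shape == (1024, 4096)
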
@@ -136,7 +136,7 @@ def get_args():
     # some default/dummy values for the tokenizer
     args.rank = 0
     args.make_vocab_size_divisible_by = 128
-    args.model_parallel_size = 1
+    args.tensor_model_parallel_size = 1

     return args
...
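
The pervasive rename from model_parallel_size to tensor_model_parallel_size reflects that the two forms of model parallelism now compose with data parallelism. A sketch of the arithmetic, with illustrative numbers only:

    # The GPUs in a job factor into tensor-parallel groups within each
    # pipeline stage; whatever is left over becomes data parallelism.
    world_size = 32
    tensor_model_parallel_size = 4    # GPUs that split each layer's matrices
    pipeline_model_parallel_size = 2  # consecutive groups of layers (stages)
    assert world_size % (tensor_model_parallel_size *
                         pipeline_model_parallel_size) == 0
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)
    print(data_parallel_size)  # 4 replicas of the 8-GPU model
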