Commit 9b174da8 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
......@@ -18,7 +18,8 @@
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.classification import Classification
from megatron import mpu
from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
......@@ -44,8 +45,21 @@ def glue_classification(num_classes, Dataset,
print_rank_0('building classification model for {} ...'.format(
args.task))
return Classification(num_classes=num_classes, num_tokentypes=2)
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = ClassificationFirstStage(
num_classes=num_classes, num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = ClassificationLastStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = ClassificationIntermediateStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = Classification(num_classes=num_classes, num_tokentypes=2)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
......
......@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
self.sample_multiplier = NUM_CHOICES
def __len__(self):
return len(self.samples)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -136,7 +136,7 @@ def get_args():
# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1
args.tensor_model_parallel_size = 1
return args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment