"git@developer.sourcefind.cn:change/sglang.git" did not exist on "62c673c46f3694c66895d29800b796c561aa1873"
Commit 9b174da8 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
from megatron import get_args from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.model.classification import Classification from megatron import mpu
from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
from tasks.eval_utils import accuracy_func_provider from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune from tasks.finetune_utils import finetune
...@@ -44,8 +45,21 @@ def glue_classification(num_classes, Dataset, ...@@ -44,8 +45,21 @@ def glue_classification(num_classes, Dataset,
print_rank_0('building classification model for {} ...'.format( print_rank_0('building classification model for {} ...'.format(
args.task)) args.task))
if mpu.get_pipeline_model_parallel_world_size() > 1:
return Classification(num_classes=num_classes, num_tokentypes=2) # Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = ClassificationFirstStage(
num_classes=num_classes, num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = ClassificationLastStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = ClassificationIntermediateStage(
num_classes=num_classes, num_tokentypes=2)
else:
model = Classification(num_classes=num_classes, num_tokentypes=2)
return model
def metrics_func_provider(): def metrics_func_provider():
"""Privde metrics callback function.""" """Privde metrics callback function."""
......
...@@ -39,6 +39,8 @@ class RaceDataset(Dataset): ...@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
print_rank_0(' >> total number of samples: {}'.format( print_rank_0(' >> total number of samples: {}'.format(
len(self.samples))) len(self.samples)))
self.sample_multiplier = NUM_CHOICES
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -136,7 +136,7 @@ def get_args(): ...@@ -136,7 +136,7 @@ def get_args():
# some default/dummy values for the tokenizer # some default/dummy values for the tokenizer
args.rank = 0 args.rank = 0
args.make_vocab_size_divisible_by = 128 args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1 args.tensor_model_parallel_size = 1
return args return args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment