Commit adbba962 authored by Vijay Korthikanti

Merge branch 'main' into vision_transformer

parents a75f1783 be473a5b
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Pretrain GPT2"""
+"""Pretrain GPT"""
 import torch
@@ -22,8 +22,11 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data.gpt2_dataset import build_train_valid_test_datasets
-from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
+from megatron.data.gpt_dataset import build_train_valid_test_datasets
+from megatron.model import (GPTModel,
+                            GPTModelFirstStage,
+                            GPTModelIntermediateStage,
+                            GPTModelLastStage)
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
@@ -31,20 +34,20 @@ from megatron.utils import average_losses_across_data_parallel_group
 def model_provider():
     """Build the model."""
-    print_rank_0('building GPT2 model ...')
+    print_rank_0('building GPT model ...')
     args = get_args()
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
-            model = GPT2ModelFirstStage(num_tokentypes=0)
+            model = GPTModelFirstStage(num_tokentypes=0)
         elif mpu.is_pipeline_last_stage():
-            model = GPT2ModelLastStage(
+            model = GPTModelLastStage(
                 num_tokentypes=0, parallel_output=True)
         else:
-            model = GPT2ModelIntermediateStage(
+            model = GPTModelIntermediateStage(
                 num_tokentypes=0)
     else:
-        model = GPT2Model(num_tokentypes=0, parallel_output=True)
+        model = GPTModel(num_tokentypes=0, parallel_output=True)
     return model
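The provider above dispatches on the rank's position in the pipeline; the stage predicates reduce to comparisons on the pipeline-parallel rank. A minimal standalone sketch of that dispatch, using hypothetical pipeline_rank/pipeline_world_size values in place of the real mpu state:

    # Sketch only: the first/intermediate/last stage tests boil down to
    # rank comparisons. pipeline_rank and pipeline_world_size stand in
    # for mpu.get_pipeline_model_parallel_rank() / world_size().
    def stage_kind(pipeline_rank, pipeline_world_size):
        if pipeline_world_size == 1:
            return 'full-model'          # no pipeline parallelism
        if pipeline_rank == 0:
            return 'first-stage'         # holds the input embedding
        if pipeline_rank == pipeline_world_size - 1:
            return 'last-stage'          # holds the output layer
        return 'intermediate-stage'      # transformer layers only

    assert stage_kind(0, 4) == 'first-stage'
    assert stage_kind(2, 4) == 'intermediate-stage'
    assert stage_kind(3, 4) == 'last-stage'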
@@ -124,7 +127,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     args = get_args()
     print_rank_0('> building train, validation, and test datasets '
-                 'for GPT2 ...')
+                 'for GPT ...')
     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
         data_prefix=args.data_path,
         data_impl=args.data_impl,
@@ -133,7 +136,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         seq_length=args.seq_length,
         seed=args.seed,
         skip_warmup=(not args.mmap_warmup))
-    print_rank_0("> finished creating GPT2 datasets ...")
+    print_rank_0("> finished creating GPT datasets ...")
     return train_ds, valid_ds, test_ds
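The train_val_test_num_samples argument is supplied by the training loop in megatron.training. A hedged sketch of the usual derivation from iteration counts and batch size; the formula below is an assumption about the surrounding loop, not part of this commit:

    # Hypothetical helper: turn iteration counts into the sample counts
    # that build_train_valid_test_datasets expects.
    def train_valid_test_num_samples(train_iters, eval_iters,
                                     eval_interval, global_batch_size):
        train = train_iters * global_batch_size
        # one eval pass every eval_interval iterations, plus a final one
        valid = (train_iters // eval_interval + 1) * eval_iters * global_batch_size
        test = eval_iters * global_batch_size
        return [train, valid, test]

    print(train_valid_test_num_samples(500000, 100, 1000, 512))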
...
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Sample Generate GPT2"""
+"""Sample Generate GPT"""
 import os
 import sys
@@ -26,7 +26,10 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
+from megatron.model import (GPTModel,
+                            GPTModelFirstStage,
+                            GPTModelLastStage,
+                            GPTModelIntermediateStage)
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
@@ -36,20 +39,20 @@ from megatron.text_generation_utils import generate_samples_interactive
 def model_provider():
     """Build the model."""
-    print_rank_0('building GPT2 model ...')
+    print_rank_0('building GPT model ...')
     args = get_args()
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
-            model = GPT2ModelFirstStage(num_tokentypes=0)
+            model = GPTModelFirstStage(num_tokentypes=0)
         elif mpu.is_pipeline_last_stage():
-            model = GPT2ModelLastStage(
+            model = GPTModelLastStage(
                 num_tokentypes=0, parallel_output=False)
         else:
-            model = GPT2ModelIntermediateStage(
+            model = GPTModelIntermediateStage(
                 num_tokentypes=0)
     else:
-        model = GPT2Model(num_tokentypes=0, parallel_output=False)
+        model = GPTModel(num_tokentypes=0, parallel_output=False)
     return model
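Note the one functional difference from the pretraining provider: parallel_output=False. With tensor model parallelism the output logits are sharded along the vocabulary dimension; training can keep them sharded because the parallel cross-entropy consumes the shards directly, but sampling needs every rank to see the full vocabulary. A rough sketch of the gather this flag implies, written against plain torch.distributed rather than Megatron's internals:

    import torch
    import torch.distributed as dist

    def gather_vocab_parallel_logits(logits_shard, world_size, group=None):
        """Sketch: all-gather [batch, seq, vocab // world_size] shards
        into full [batch, seq, vocab] logits, as parallel_output=False
        implies inside the model."""
        shards = [torch.empty_like(logits_shard) for _ in range(world_size)]
        dist.all_gather(shards, logits_shard, group=group)
        return torch.cat(shards, dim=-1)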
...
@@ -108,8 +108,8 @@ def get_model(model_type):
     if model_type == 'BERT':
         from pretrain_bert import model_provider
-    elif model_type == 'GPT2':
-        from pretrain_gpt2 import model_provider
+    elif model_type == 'GPT':
+        from pretrain_gpt import model_provider
     elif model_type == 'RACE':
         from tasks.race.finetune import model_provider
     elif model_type in ['MNLI', 'QQP']:
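A quick aside on the final branch: a string never compares equal to a list in Python, so membership (in) is the test that makes the MNLI/QQP branch reachable. A two-line illustration:

    # '==' between a str and a list is always False; 'in' tests membership.
    assert ('MNLI' == ['MNLI', 'QQP']) is False
    assert 'MNLI' in ['MNLI', 'QQP']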
@@ -177,7 +177,7 @@ def get_mp_merge_args(parser):
     group = parser.add_argument_group(title='mp merge')
     group.add_argument('--model-type', type=str, required=True,
-                       choices=['BERT', 'GPT2', 'RACE', 'MNLI', 'QQP'],
+                       choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                        help='Type of the model.')
     return parser
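For reference, a hedged sketch of how this argument group behaves once wired into a parser; only --model-type and its choices come from the diff, the rest is illustrative scaffolding:

    import argparse

    # Minimal reproduction of the mp-merge argument group.
    parser = argparse.ArgumentParser(description='merge model-parallel partitions')
    group = parser.add_argument_group(title='mp merge')
    group.add_argument('--model-type', type=str, required=True,
                       choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                       help='Type of the model.')

    args = parser.parse_args(['--model-type', 'GPT'])
    assert args.model_type == 'GPT'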
...