Commit adbba962 authored by Vijay Korthikanti
Browse files

Merge branch 'main' into vision_transformer

parents a75f1783 be473a5b
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT2"""
"""Pretrain GPT"""
import torch
......@@ -22,8 +22,11 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import (GPTModel,
GPTModelFirstStage,
GPTModelIntermediateStage,
GPTModelLastStage)
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
......@@ -31,20 +34,20 @@ from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
print_rank_0('building GPT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
model = GPTModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
model = GPTModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPT2ModelIntermediateStage(
model = GPTModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=True)
model = GPTModel(num_tokentypes=0, parallel_output=True)
return model
......@@ -124,7 +127,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for GPT2 ...')
'for GPT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
......@@ -133,7 +136,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
print_rank_0("> finished creating GPT2 datasets ...")
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
......
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT2"""
"""Sample Generate GPT"""
import os
import sys
......@@ -26,7 +26,10 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
from megatron.model import (GPTModel,
GPTModelFirstStage,
GPTModelLastStage,
GPTModelIntermediateStage)
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
......@@ -36,20 +39,20 @@ from megatron.text_generation_utils import generate_samples_interactive
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
print_rank_0('building GPT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
model = GPTModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
model = GPTModelLastStage(
num_tokentypes=0, parallel_output=False)
else:
model = GPT2ModelIntermediateStage(
model = GPTModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=False)
model = GPTModel(num_tokentypes=0, parallel_output=False)
return model
......
......@@ -108,8 +108,8 @@ def get_model(model_type):
if model_type == 'BERT':
from pretrain_bert import model_provider
elif model_type == 'GPT2':
from pretrain_gpt2 import model_provider
elif model_type == 'GPT':
from pretrain_gpt import model_provider
elif model_type == 'RACE':
from tasks.race.finetune import model_provider
elif model_type == ['MNLI', 'QQP']:
......@@ -177,7 +177,7 @@ def get_mp_merge_args(parser):
group = parser.add_argument_group(title='mp merge')
group.add_argument('--model-type', type=str, required=True,
choices=['BERT', 'GPT2', 'RACE', 'MNLI', 'QQP'],
choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
help='Type of the mdoel.')
return parser
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment