Commit 4635bdf7 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'readme_update_jan_2021' into 'main'

Readme update + change gpt2 to gpt

See merge request ADLR/megatron-lm!206
parents 86eb5bd8 152aab30
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT2"""
"""Sample Generate GPT"""
import os
import sys
......@@ -26,7 +26,10 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
from megatron.model import (GPTModel,
GPTModelFirstStage,
GPTModelLastStage,
GPTModelIntermediateStage)
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
......@@ -36,20 +39,20 @@ from megatron.text_generation_utils import generate_samples_interactive
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
print_rank_0('building GPT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
model = GPTModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
model = GPTModelLastStage(
num_tokentypes=0, parallel_output=False)
else:
model = GPT2ModelIntermediateStage(
model = GPTModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=False)
model = GPTModel(num_tokentypes=0, parallel_output=False)
return model
......
......@@ -108,8 +108,8 @@ def get_model(model_type):
if model_type == 'BERT':
from pretrain_bert import model_provider
elif model_type == 'GPT2':
from pretrain_gpt2 import model_provider
elif model_type == 'GPT':
from pretrain_gpt import model_provider
elif model_type == 'RACE':
from tasks.race.finetune import model_provider
elif model_type == ['MNLI', 'QQP']:
......@@ -177,7 +177,7 @@ def get_mp_merge_args(parser):
group = parser.add_argument_group(title='mp merge')
group.add_argument('--model-type', type=str, required=True,
choices=['BERT', 'GPT2', 'RACE', 'MNLI', 'QQP'],
choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
help='Type of the model.')
return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment