"vscode:/vscode.git/clone" did not exist on "091da63d1c7bfb77df1be0dfe9c8e622ff6ab66f"
Commit 9350ee08 authored by Neel Kant's avatar Neel Kant
Browse files

Merge staging into ict-stable

parents 9238c57a 63262827
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import os
import sys
import time
import torch
from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
def get_args():
"""Return arguments."""
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
return _GLOBAL_ARGS
def get_tokenizer():
"""Return tokenizer."""
_ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
return _GLOBAL_TOKENIZER
def get_tensorboard_writer():
"""Return tensorboard writer. It can be None so no need
to check if it is initialized."""
return _GLOBAL_TENSORBOARD_WRITER
def get_adlr_autoresume():
"""ADLR autoresume object. It can be None so no need
to check if it is initialized."""
return _GLOBAL_ADLR_AUTORESUME
def get_timers():
"""Return timers."""
_ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={}):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults)
_build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
def _parse_args(extra_args_provider=None, defaults={}):
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults)
return _GLOBAL_ARGS
def _build_tokenizer(args):
"""Initialize tokenizer."""
global _GLOBAL_TOKENIZER
_ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
_GLOBAL_TOKENIZER = build_tokenizer(args)
def _set_tensorboard_writer(args):
"""Set tensorboard writer."""
global _GLOBAL_TENSORBOARD_WRITER
_ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
'tensorboard writer')
if hasattr(args, 'tensorboard_dir') and \
args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
print('> setting tensorboard ...')
_GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
log_dir=args.tensorboard_dir)
except ModuleNotFoundError:
print('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.', flush=True)
def _set_adlr_autoresume(args):
"""Initialize ADLR autoresume."""
global _GLOBAL_ADLR_AUTORESUME
_ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
if args.adlr_autoresume:
if args.rank == 0:
print('enabling autoresume ...', flush=True)
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except:
print('ADLR autoresume is not available, exiting ...')
sys.exit()
_GLOBAL_ADLR_AUTORESUME = AutoResume
def _set_timers():
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If timing is in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '_time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(string, flush=True)
else:
print(string, flush=True)
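The Timers utility above is self-contained, so a short usage sketch may help. The import path and the CUDA requirement (start() and stop() call torch.cuda.synchronize()) are assumptions based on this file; the snippet is not part of the commit.

import time

from megatron.global_vars import Timers  # assumed path for the module above

timers = Timers()
timers('data-loading').start()
time.sleep(0.1)                      # stand-in for real work
timers('data-loading').stop()
# elapsed() returns seconds (and resets by default); log() prints milliseconds.
print(timers('data-loading').elapsed(reset=False))
timers.log(['data-loading'], normalizer=1.0)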
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron initialization."""
import random
import os
import numpy as np
import torch
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults={}):
"""Set global variables, initialize distributed, and
set autoresume and random seeds."""
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Pytorch distributed.
_initialize_distributed()
# Autoresume.
_init_autoresume()
# Random seeds for reproducibility.
args = get_args()
if args.rank == 0:
print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed)
# Write arguments to tensorboard.
_write_args_to_tensorboard()
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args = get_args()
if torch.distributed.is_initialized():
if args.rank == 0:
print('torch distributed is already initialized, '
'skipping initialization ...', flush=True)
args.rank = torch.distributed.get_rank()
args.world_size = torch.distributed.get_world_size()
device = torch.cuda.current_device()
local_rank = args.rank % torch.cuda.device_count()
assert local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
else:
if args.rank == 0:
print('> initializing torch distributed ...', flush=True)
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
assert args.local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
else:
args.local_rank = device
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def _init_autoresume():
"""Set autoresume start time."""
autoresume = get_adlr_autoresume()
if autoresume:
torch.distributed.barrier()
autoresume.init()
torch.distributed.barrier()
def _set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
def _write_args_to_tensorboard():
"""Write arguments to tensorboard."""
args = get_args()
writer = get_tensorboard_writer()
if writer:
for arg in vars(args):
writer.add_text(arg, str(getattr(args, arg)))
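_initialize_distributed() above builds a TCP init method from the MASTER_ADDR and MASTER_PORT environment variables and maps each rank to a GPU via rank % device_count. Below is a minimal standalone sketch of the same wiring, assuming a single process and one visible GPU; it is illustrative and not part of the commit.

import os

import torch

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '6000')

rank, world_size = 0, 1
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)

init_method = 'tcp://{}:{}'.format(os.environ['MASTER_ADDR'],
                                   os.environ['MASTER_PORT'])
torch.distributed.init_process_group(backend='nccl',
                                     world_size=world_size,
                                     rank=rank,
                                     init_method=init_method)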
@@ -12,59 +12,68 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch DataLoader for TFRecords"""
-import torch
-from torch.optim.lr_scheduler import _LRScheduler
-import math
-from megatron.utils import print_rank_0
+"""Learning rate decay functions."""
+import math
+from megatron import print_rank_0
-class AnnealingLR(_LRScheduler):
-    """Anneals the learning rate"""
-    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
-    def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
-                 decay_style=None, last_iter=-1, min_lr=0.0,
+class AnnealingLR(object):
+    """Anneals the learning rate."""
+    def __init__(self, optimizer, start_lr,
+                 warmup_iter, total_iters,
+                 decay_style, last_iter, min_lr=0.0,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
+        # Class values.
         self.optimizer = optimizer
         self.start_lr = start_lr
         self.min_lr = min_lr
         self.warmup_iter = warmup_iter
-        self.num_iters = last_iter + 1
-        self.end_iter = num_iters
-        self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
-                           else None
+        self.num_iters = last_iter
+        self.end_iter = total_iters
+        assert self.end_iter > 0
+        self.decay_style = decay_style
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, 'both override and '\
                 'use-checkpoint are set.'
+        # Set the learning rate
         self.step(self.num_iters)
-        if torch.distributed.get_rank() == 0:
-            print('learning rate decaying', decay_style)
+        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
     def get_lr(self):
-        # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
+        """Learning rate decay functions from:
+              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
         num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
+        # Warmup.
         if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
             return float(self.start_lr) * num_iters_ / self.warmup_iter
+        num_iters_ = num_iters_ - self.warmup_iter
+        if self.decay_style == 'linear':
+            lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
+        elif self.decay_style == 'cosine':
+            lr = self.start_lr / 2.0 * (math.cos(
+                math.pi * num_iters_ / self.end_iter) + 1)
+        elif self.decay_style == 'exponential':
+            # exp(-0.693) = 1/2
+            lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
         else:
-            if self.decay_style == self.DECAY_STYLES[0]:
-                lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
-            elif self.decay_style == self.DECAY_STYLES[1]:
-                lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
-            elif self.decay_style == self.DECAY_STYLES[2]:
-                # exp(-0.693) = 1/2
-                lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
-            else:
-                lr = self.start_lr
-        return max(lr, self.min_lr)
+            lr = self.start_lr
+        return max(lr, self.min_lr)
     def step(self, step_num=None):
+        """Set lr for all parameters groups."""
         if step_num is None:
             step_num = self.num_iters + 1
         self.num_iters = step_num
@@ -72,42 +81,46 @@ class AnnealingLR(_LRScheduler):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
     def state_dict(self):
-        sd = {
+        state_dict = {
            'start_lr': self.start_lr,
            'warmup_iter': self.warmup_iter,
            'num_iters': self.num_iters,
            'decay_style': self.decay_style,
            'end_iter': self.end_iter,
            'min_lr': self.min_lr
        }
-        return sd
+        return state_dict
-    def check_and_set_(self, cls_value, sd_value, name):
+    def _check_and_set(self, cls_value, sd_value, name):
+        """Auxiliary function for checking the values in the checkpoint and
+        setting them."""
         if self.override_lr_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value
-        else:
-            if not self.use_checkpoint_lr_scheduler:
-                assert cls_value == sd_value, 'AnnealingLR: class input value' \
-                    'and checkpoint values for {} do not match'.format(name)
-            print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
-                                                                      name))
-            return sd_value
+        if not self.use_checkpoint_lr_scheduler:
+            assert cls_value == sd_value, 'AnnealingLR: class input value' \
+                'and checkpoint values for {} do not match'.format(name)
+        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
+                                                                  name))
+        return sd_value
     def load_state_dict(self, sd):
-        self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
+        self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
                                             'learning rate')
-        self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
+        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                           'minimum learning rate')
-        self.warmup_iter = self.check_and_set_(self.warmup_iter,
+        self.warmup_iter = self._check_and_set(self.warmup_iter,
                                                sd['warmup_iter'],
                                                'warmup iterations')
-        self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
+        self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
                                             'total number of iterations')
-        self.decay_style = self.check_and_set_(self.decay_style,
+        self.decay_style = self._check_and_set(self.decay_style,
                                                sd['decay_style'],
                                                'decay style')
...
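The rewritten AnnealingLR applies linear warmup followed by one of three decay styles. Below is a standalone sketch of the same schedule math; the function name and the example values are illustrative and not part of the commit.

import math

def annealed_lr(num_iters, start_lr, warmup_iter, end_iter,
                decay_style='linear', min_lr=0.0):
    """Mirror of AnnealingLR.get_lr() above for a single iteration count."""
    num_iters_ = min(num_iters, end_iter - warmup_iter)
    # Linear warmup.
    if warmup_iter > 0 and num_iters <= warmup_iter:
        return float(start_lr) * num_iters_ / warmup_iter
    num_iters_ = num_iters_ - warmup_iter
    if decay_style == 'linear':
        lr = start_lr * (end_iter - num_iters_) / end_iter
    elif decay_style == 'cosine':
        lr = start_lr / 2.0 * (math.cos(math.pi * num_iters_ / end_iter) + 1)
    elif decay_style == 'exponential':
        # exp(-0.693) = 1/2, i.e. the rate roughly halves over end_iter steps.
        lr = start_lr * math.exp(-0.693 * num_iters_ / end_iter)
    else:
        lr = start_lr
    return max(lr, min_lr)

# e.g. cosine decay over 100k iterations with 1k warmup steps.
print(annealed_lr(50000, 1.5e-4, 1000, 100000, decay_style='cosine'))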
@@ -17,6 +17,7 @@
 import torch
+from megatron import get_args
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits
@@ -106,27 +107,10 @@ class BertLMHead(MegatronModule):
 class BertModel(MegatronModule):
     """Bert Language model."""
-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 add_binary_head=False,
-                 ict_head_size=None,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=2, add_binary_head=True,
+                 ict_head_size=None, parallel_output=True):
         super(BertModel, self).__init__()
+        args = get_args()
         self.add_binary_head = add_binary_head
         self.ict_head_size = ict_head_size
@@ -134,46 +118,31 @@ class BertModel(MegatronModule):
         assert not (self.add_binary_head and self.add_ict_head)
         self.parallel_output = parallel_output
-        init_method = init_method_normal(init_method_std)
+        init_method = init_method_normal(args.init_method_std)
         add_pooler = self.add_binary_head or self.add_ict_head
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)
         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=add_pooler,
-            attention_mask_func=bert_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
             init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            scaled_init_method=scaled_init_method)
         if not self.add_ict_head:
             self.lm_head = BertLMHead(
                 self.language_model.embedding.word_embeddings.weight.size(0),
-                hidden_size, init_method, layernorm_epsilon, parallel_output)
+                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
             self._lm_head_key = 'lm_head'
         if self.add_binary_head:
-            self.binary_head = get_linear_layer(hidden_size, 2, init_method)
+            self.binary_head = get_linear_layer(args.hidden_size, 2,
+                                                init_method)
             self._binary_head_key = 'binary_head'
         elif self.add_ict_head:
-            self.ict_head = get_linear_layer(hidden_size, ict_head_size, init_method)
+            self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
             self._ict_head_key = 'ict_head'
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):
         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)
...
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classification model."""
import torch
from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class Classification(MegatronModule):
def __init__(self, num_classes, num_tokentypes=2):
super(Classification, self).__init__()
args = get_args()
self.num_classes = num_classes
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
# Multi-choice head.
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes,
init_method)
self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output)
# Reshape back to separate choices.
classification_logits = classification_logits.view(-1, self.num_classes)
return classification_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
else:
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
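The classification head above is simply pooled output, then dropout, then a linear layer. A toy standalone illustration of that shape flow; the sizes are made up and torch.nn.Linear stands in for Megatron's get_linear_layer.

import torch

hidden_size, num_classes, batch_size = 1024, 3, 4
classification_dropout = torch.nn.Dropout(0.1)
classification_head = torch.nn.Linear(hidden_size, num_classes)

# Stand-in for the pooled output of the language model.
pooled_output = torch.randn(batch_size, hidden_size)
logits = classification_head(classification_dropout(pooled_output))
print(logits.shape)  # torch.Size([4, 3])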
@@ -17,6 +17,7 @@
 import torch
+from megatron import get_args
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits
@@ -34,53 +35,24 @@ def gpt2_attention_mask_func(attention_scores, ltor_mask):
 class GPT2Model(MegatronModule):
     """GPT-2 Language model."""
-    def __init__(self,
-                 num_layers,
-                 vocab_size,
-                 hidden_size,
-                 num_attention_heads,
-                 embedding_dropout_prob,
-                 attention_dropout_prob,
-                 output_dropout_prob,
-                 max_sequence_length,
-                 checkpoint_activations,
-                 checkpoint_num_layers=1,
-                 layernorm_epsilon=1.0e-5,
-                 init_method_std=0.02,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 apply_query_key_layer_scaling=False,
-                 attention_softmax_in_fp32=False):
+    def __init__(self, num_tokentypes=0, parallel_output=True):
         super(GPT2Model, self).__init__()
+        args = get_args()
         self.parallel_output = parallel_output
         self.language_model, self._language_model_key = get_language_model(
-            num_layers=num_layers,
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            num_attention_heads=num_attention_heads,
-            embedding_dropout_prob=embedding_dropout_prob,
-            attention_dropout_prob=attention_dropout_prob,
-            output_dropout_prob=output_dropout_prob,
-            max_sequence_length=max_sequence_length,
+            attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            attention_mask_func=gpt2_attention_mask_func,
-            checkpoint_activations=checkpoint_activations,
-            checkpoint_num_layers=checkpoint_num_layers,
-            layernorm_epsilon=layernorm_epsilon,
-            init_method=init_method_normal(init_method_std),
-            scaled_init_method=scaled_init_method_normal(init_method_std,
-                                                         num_layers),
-            residual_connection_post_layernorm=False,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32)
+            init_method=init_method_normal(args.init_method_std),
+            scaled_init_method=scaled_init_method_normal(args.init_method_std,
+                                                         args.num_layers))
     def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):
         # Language model.
         lm_output = self.language_model(input_ids,
@@ -94,10 +66,13 @@ class GPT2Model(MegatronModule):
             lm_output, presents = lm_output
         # Output.
+        parallel_output = self.parallel_output
+        if forward_method_parallel_output is not None:
+            parallel_output = forward_method_parallel_output
         output = parallel_lm_logits(
             lm_output,
             self.language_model.embedding.word_embeddings.weight,
-            self.parallel_output)
+            parallel_output)
         if get_key_value:
             output = [output, presents]
...
@@ -18,13 +18,13 @@
 import torch
 import torch.nn.functional as F
+from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
-from .transformer import ParallelTransformer
-from .transformer import TransformerHyperparameters
-from .utils import gelu
-from .utils import get_linear_layer
+from megatron.model.transformer import ParallelTransformer
+from megatron.model.utils import gelu
+from megatron.model.utils import get_linear_layer
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -40,52 +40,20 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     # Gather if needed.
     if parallel_output:
         return logits_parallel
-    else:
-        return mpu.gather_from_model_parallel_region(logits_parallel)
+    return mpu.gather_from_model_parallel_region(logits_parallel)
-def get_language_model(num_layers,
-                       vocab_size,
-                       hidden_size,
-                       num_attention_heads,
-                       embedding_dropout_prob,
-                       attention_dropout_prob,
-                       output_dropout_prob,
-                       max_sequence_length,
-                       num_tokentypes,
-                       attention_mask_func,
-                       add_pooler,
-                       checkpoint_activations,
-                       checkpoint_num_layers,
-                       layernorm_epsilon,
-                       init_method,
-                       scaled_init_method,
-                       residual_connection_post_layernorm,
-                       apply_query_key_layer_scaling,
-                       attention_softmax_in_fp32):
-    # Transformer hyperparameters.
-    transformer_hparams = TransformerHyperparameters(
-        hidden_size=hidden_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        attention_dropout_prob=attention_dropout_prob,
-        output_dropout_prob=output_dropout_prob,
-        mlp_activation_func=gelu,
-        layernorm_epsilon=layernorm_epsilon,
-        init_method=init_method,
-        output_layer_init_method=scaled_init_method,
-        checkpoint_activations=checkpoint_activations,
-        checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=attention_softmax_in_fp32)
+def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+                       init_method, scaled_init_method):
+    """Build language model and return along with the key to save."""
     # Language model.
     language_model = TransformerLanguageModel(
-        transformer_hparams=transformer_hparams,
         attention_mask_func=attention_mask_func,
-        vocab_size=vocab_size,
-        max_sequence_length=max_sequence_length,
-        embedding_dropout_prob=embedding_dropout_prob,
+        mlp_activation_func=gelu,
+        init_method=init_method,
+        output_layer_init_method=scaled_init_method,
        num_tokentypes=num_tokentypes,
        add_pooler=add_pooler)
     # key used for checkpoints.
@@ -293,33 +261,33 @@ class TransformerLanguageModel(MegatronModule):
        will ignore this embedding
     """
     def __init__(self,
-                 transformer_hparams,
                  attention_mask_func,
-                 vocab_size,
-                 max_sequence_length,
-                 embedding_dropout_prob,
+                 mlp_activation_func,
+                 init_method,
+                 output_layer_init_method,
                  num_tokentypes=0,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__()
+        args = get_args()
-        self.hidden_size = transformer_hparams['hidden_size']
+        self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
-        self.init_method = transformer_hparams['init_method']
+        self.init_method = init_method
         self.add_pooler = add_pooler
         # Embeddings
         self.embedding = Embedding(self.hidden_size,
-                                   vocab_size,
-                                   max_sequence_length,
-                                   embedding_dropout_prob,
+                                   args.padded_vocab_size,
+                                   args.max_position_embeddings,
+                                   args.hidden_dropout,
                                    self.init_method,
                                    self.num_tokentypes)
         self._embedding_key = 'embedding'
         # Transformer
         self.transformer = ParallelTransformer(
-            transformer_hparams,
-            attention_mask_func)
+            attention_mask_func, mlp_activation_func,
+            self.init_method, output_layer_init_method)
         self._transformer_key = 'transformer'
         # Pooler
...
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multiple choice model."""
import torch
from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class MultipleChoice(MegatronModule):
def __init__(self, num_tokentypes=2):
super(MultipleChoice, self).__init__()
args = get_args()
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
# Multi-choice head.
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method)
self._multichoice_head_key = 'multichoice_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
# [batch, choices, sequence] --> [batch * choices, sequence] -->
# transformer --> [batch, choices] --> softmax
# Ensure the shape is [batch-size, choices, sequence]
assert len(input_ids.shape) == 3
assert len(attention_mask.shape) == 3
assert len(tokentype_ids.shape) == 3
# Reshape and treat choice dimension the same as batch.
num_choices = input_ids.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1))
attention_mask = attention_mask.view(-1, attention_mask.size(-1))
tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
multichoice_output = self.multichoice_dropout(pooled_output)
multichoice_logits = self.multichoice_head(multichoice_output)
# Reshape back to separate choices.
multichoice_logits = multichoice_logits.view(-1, num_choices)
return multichoice_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict)
else:
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._multichoice_head_key))
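MultipleChoice.forward above folds the choice dimension into the batch dimension before the encoder and unfolds it after the head. A standalone illustration of that reshape; the shapes are made up and the language model plus pooler are replaced by a random tensor.

import torch

batch_size, num_choices, seq_len, hidden_size = 2, 4, 16, 32
input_ids = torch.randint(0, 1000, (batch_size, num_choices, seq_len))

# Fold: [batch, choices, seq] -> [batch * choices, seq].
flat_ids = input_ids.view(-1, input_ids.size(-1))

# Stand-in for the language model + pooler: one vector per folded sequence.
pooled_output = torch.randn(flat_ids.size(0), hidden_size)
multichoice_head = torch.nn.Linear(hidden_size, 1)

# Unfold: [batch * choices, 1] -> [batch, choices] scores.
multichoice_logits = multichoice_head(pooled_output).view(-1, num_choices)
print(multichoice_logits.shape)  # torch.Size([2, 4])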
@@ -20,6 +20,7 @@ import math
 import torch
 from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
@@ -45,85 +46,6 @@ from megatron.module import MegatronModule
        unmaksed-attention-scores, attention-mask)
 """
-class TransformerHyperparameters:
-    """Hyperparameters used to build and run the transformer.
-    Arguments:
-        hidden_size: hidden size (h)
-        num_layers: number of layers (l)
-        num_attention_heads: number of attention heads (n)
-        attention_dropout_prob: dropout probability for the attention
-                                probabiliies
-        output_dropout_prob: dropout probability for the output
-                             layers (attention output and mlp output)
-        mlp_activation_func: activation function for the mlp layer
-        layernorm_epsilon: tolerance parameters used for layer norm
-                           dividions
-        init_method: init method used for all weights except layer
-                     norm and output weights
-        output_layer_init_method: init method for output weights (
-                                  attention output and mlp output)
-        checkpoint_activations: flag to use activation checkpointing
-        checkpoint_num_layers: number of layers use in each chunk of
-                               activation checkpointing
-        apply_residual_connection_post_layernorm: Take the post layer-norm
-            values for resudual connecton. BERT: True, GPT-2: False
-    """
-    def __init__(self,
-                 hidden_size=None,
-                 num_layers=None,
-                 num_attention_heads=None,
-                 attention_dropout_prob=None,
-                 output_dropout_prob=None,
-                 mlp_activation_func=None,
-                 layernorm_epsilon=None,
-                 init_method=None,
-                 output_layer_init_method=None,
-                 checkpoint_activations=None,
-                 checkpoint_num_layers=None,
-                 apply_residual_connection_post_layernorm=None,
-                 apply_query_key_layer_scaling=None,
-                 attention_softmax_in_fp32=None):
-        self.params_dict = {}
-        self.params_dict['hidden_size'] = hidden_size
-        self.params_dict['num_layers'] = num_layers
-        self.params_dict['num_attention_heads'] = num_attention_heads
-        self.params_dict['attention_dropout_prob'] = attention_dropout_prob
-        self.params_dict['output_dropout_prob'] = output_dropout_prob
-        self.params_dict['mlp_activation_func'] = mlp_activation_func
-        self.params_dict['layernorm_epsilon'] = layernorm_epsilon
-        self.params_dict['init_method'] = init_method
-        self.params_dict['output_layer_init_method'] = output_layer_init_method
-        self.params_dict['checkpoint_activations'] = checkpoint_activations
-        self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
-        self.params_dict['apply_residual_connection_post_layernorm'] \
-            = apply_residual_connection_post_layernorm
-        self.params_dict['apply_query_key_layer_scaling'] \
-            = apply_query_key_layer_scaling
-        self.params_dict['attention_softmax_in_fp32'] \
-            = attention_softmax_in_fp32
-    def __getitem__(self, key):
-        """Custom retrieval with error checks."""
-        try:
-            value = self.params_dict[key]
-        except KeyError:
-            raise Exception(
-                'could not find {} in transformer hyperparameters'.format(key))
-        except Exception as e:
-            print('unexpected error in transformer hyperparameters:', e)
-            raise Exception()
-        else:
-            assert value is not None, \
-                'parameter value for {} is not set in transformer '\
-                'hyperparameters'.format(key)
-            return value
-        raise Exception('should not be here')
 class ParallelMLP(MegatronModule):
     """MLP.
@@ -133,26 +55,28 @@ class ParallelMLP(MegatronModule):
     applied.
     """
-    def __init__(self, hyperparameters):
+    def __init__(self, mlp_activation_func, init_method,
+                 output_layer_init_method):
         super(ParallelMLP, self).__init__()
+        args = get_args()
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
-            hyperparameters['hidden_size'],
-            4*hyperparameters['hidden_size'],
+            args.hidden_size,
+            4*args.hidden_size,
             gather_output=False,
-            init_method=hyperparameters['init_method'])
+            init_method=init_method)
-        self.activation_func = hyperparameters['mlp_activation_func']
+        self.activation_func = mlp_activation_func
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4*hyperparameters['hidden_size'],
-            hyperparameters['hidden_size'],
+            4*args.hidden_size,
+            args.hidden_size,
             input_is_parallel=True,
-            init_method=hyperparameters['output_layer_init_method'])
+            init_method=output_layer_init_method)
-        self.dropout = torch.nn.Dropout(hyperparameters['output_dropout_prob'])
+        self.dropout = torch.nn.Dropout(args.hidden_dropout)
     def forward(self, hidden_states):
@@ -174,51 +98,47 @@ class ParallelSelfAttention(MegatronModule):
     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """
-    def __init__(self, hyperparameters, attention_mask_func, layer_number):
+    def __init__(self, attention_mask_func, init_method,
+                 output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
+        args = get_args()
         self.attention_mask_func = attention_mask_func
-        self.apply_query_key_layer_scaling \
-            = hyperparameters['apply_query_key_layer_scaling']
-        self.attention_softmax_in_fp32 \
-            = hyperparameters['attention_softmax_in_fp32']
+        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         # Per attention head and per partition values.
         world_size = mpu.get_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(
-            hyperparameters['hidden_size'], world_size)
+        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+                                                    world_size)
         self.hidden_size_per_attention_head = mpu.divide(
-            hyperparameters['hidden_size'],
-            hyperparameters['num_attention_heads'])
+            args.hidden_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
-            hyperparameters['num_attention_heads'], world_size)
+            args.num_attention_heads, world_size)
         # Strided linear layer.
         self.query_key_value = mpu.ColumnParallelLinear(
-            hyperparameters['hidden_size'],
-            3*hyperparameters['hidden_size'],
+            args.hidden_size,
+            3*args.hidden_size,
             stride=3,
             gather_output=False,
-            init_method=hyperparameters['init_method'])
+            init_method=init_method)
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(
-            hyperparameters['attention_dropout_prob'])
+        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
         # Output.
         self.dense = mpu.RowParallelLinear(
-            hyperparameters['hidden_size'],
-            hyperparameters['hidden_size'],
+            args.hidden_size,
+            args.hidden_size,
            input_is_parallel=True,
-            init_method=hyperparameters['output_layer_init_method'])
+            init_method=output_layer_init_method)
-        self.output_dropout = torch.nn.Dropout(
-            hyperparameters['output_dropout_prob'])
+        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
     def _transpose_for_scores(self, tensor):
@@ -369,30 +289,34 @@ class ParallelTransformerLayer(MegatronModule):
     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """
-    def __init__(self, hyperparameters, attention_mask_func, layer_number):
+    def __init__(self, attention_mask_func, mlp_activation_func,
+                 init_method, output_layer_init_method, layer_number):
+        args = get_args()
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.apply_residual_connection_post_layernorm \
-            = hyperparameters['apply_residual_connection_post_layernorm']
+            = args.apply_residual_connection_post_layernorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)
         # Self attention.
-        self.attention = ParallelSelfAttention(
-            hyperparameters, attention_mask_func, layer_number)
+        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
+                                               output_layer_init_method,
+                                               layer_number)
         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)
         # MLP
-        self.mlp = ParallelMLP(hyperparameters)
+        self.mlp = ParallelMLP(mlp_activation_func, init_method,
+                               output_layer_init_method)
     def forward(self, hidden_states, attention_mask, layer_past=None,
@@ -434,25 +358,28 @@ class ParallelTransformerLayer(MegatronModule):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""
-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, attention_mask_func, mlp_activation_func,
+                 init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
+        args = get_args()
         # Store activation checkpoiting flag.
-        self.checkpoint_activations = hyperparameters['checkpoint_activations']
-        self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
+        self.checkpoint_activations = args.checkpoint_activations
+        self.checkpoint_num_layers = args.checkpoint_num_layers
         def get_layer(layer_number):
             return ParallelTransformerLayer(
-                hyperparameters, attention_mask_func, layer_number)
+                attention_mask_func, mlp_activation_func,
+                init_method, output_layer_init_method, layer_number)
         # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i+1) for i in range(hyperparameters['num_layers'])])
+            [get_layer(i+1) for i in range(args.num_layers)])
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
-            hyperparameters['hidden_size'],
-            eps=hyperparameters['layernorm_epsilon'])
+            args.hidden_size,
+            eps=args.layernorm_epsilon)
     def _checkpointed_forward(self, hidden_states, attention_mask):
...
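The common thread in the model changes above is that modules now pull their hyperparameters from a process-wide args object via get_args() instead of taking them as constructor arguments. Below is a minimal standalone sketch of that pattern, with SimpleNamespace standing in for Megatron's parsed arguments; it is illustrative and not part of the commit.

from types import SimpleNamespace

import torch
import torch.nn.functional as F

_GLOBAL_ARGS = SimpleNamespace(hidden_size=1024, hidden_dropout=0.1)

def get_args():
    return _GLOBAL_ARGS

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super(TinyMLP, self).__init__()
        args = get_args()  # configuration comes from the global, not the caller
        self.dense_h_to_4h = torch.nn.Linear(args.hidden_size, 4 * args.hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * args.hidden_size, args.hidden_size)
        self.dropout = torch.nn.Dropout(args.hidden_dropout)

    def forward(self, hidden_states):
        return self.dropout(self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(hidden_states))))

print(TinyMLP()(torch.randn(2, 1024)).shape)  # torch.Size([2, 1024])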
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Megatron Module"""
-
 import torch
...
@@ -13,150 +13,99 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Sample Generate GPT2"""
-import os
-import random
-import json
-import copy
-import numpy as np
-import torch
-import torch.nn.functional as F
-import argparse
-import time
-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
-from megatron.utils import print_rank_0
-def get_model(args):
-    """Build the model."""
-    print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
-                      hidden_size=args.hidden_size,
-                      num_attention_heads=args.num_attention_heads,
-                      embedding_dropout_prob=args.hidden_dropout,
-                      attention_dropout_prob=args.attention_dropout,
-                      output_dropout_prob=args.hidden_dropout,
-                      max_sequence_length=args.max_position_embeddings,
-                      checkpoint_activations=args.checkpoint_activations,
-                      checkpoint_num_layers=args.checkpoint_num_layers,
-                      parallel_output=False)
-    if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on model parallel rank {}: {}'.format(
-            mpu.get_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
-    # GPU allocation.
-    model.cuda(torch.cuda.current_device())
-    # Fp16 conversion.
-    if args.fp16:
-        model = FP16_Module(model)
-    # Wrap model for distributed training.
-    model = DDP(model)
-    return model
-def setup_model(args):
-    """Setup model and optimizer."""
-    model = get_model(args)
-    if args.load is not None:
-        _ = load_checkpoint(
-            model, None, None, args)
-    return model
+"""Utilities for generating text."""
+import copy
+import json
+import os
+import time
+import torch
+import torch.nn.functional as F
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron.utils import get_ltor_masks_and_position_ids
-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(args.batch_size, -1).contiguous()
-    device = args.device
-    tokens = tokens.to(device)
-    # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+def get_batch(context_tokens):
+    """Generate batch from context tokens."""
+    args = get_args()
+    tokenizer = get_tokenizer()
+    # Move to GPU.
+    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
+    # Get the attention mask and postition ids.
+    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
-        args.eod_token,
+        tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        False)
-    # Fp16 conversion.
-    if args.fp16:
-        attention_mask = attention_mask.half()
+        args.eod_mask_loss,
+        args.fp16)
     return tokens, attention_mask, position_ids
 def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+    """ This function has been mostly taken from huggingface conversational
+     ai code at
+         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
+         conversational-ai-with-transfer-learning-2d818ac26313 """
     if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
+        # Remove all tokens with a probability less than the
+        # last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value
     if top_p > 0.0:
-        #convert to 1D
-        # logits=logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        # Cconvert to 1D
+        sorted_logits, sorted_indices = torch.sort(
+            logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
+                                        dim=-1)
         # Remove tokens with cumulative probability above the threshold
         sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        # Shift the indices to the right to keep also the first token
+        # above the threshold
+        sorted_indices_to_remove[..., 1:] \
+            = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
         for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value
-        #going back to 2D
-        # logits=logits.view(1, -1).contiguous()
    return logits
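top_k_logits above masks everything outside the top-k tokens and/or the top-p nucleus with -inf. A hypothetical usage sketch, assuming the function above is in scope; the vocabulary size and values are made up and the snippet is not part of the commit.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 50257)            # [batch, vocab] scores from the model

filtered = top_k_logits(logits.clone(), top_k=0, top_p=0.9)
probs = F.softmax(filtered, dim=-1)       # masked entries get probability ~0
next_tokens = torch.multinomial(probs, num_samples=1)
print(next_tokens.shape)                  # torch.Size([2, 1])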
def generate_samples_input_from_file(model, tokenizer, args):
if args.sample_input_file == "": def generate_samples_input_from_file(model):
if mpu.get_model_parallel_rank() == 0:
print("args.sample_input_file CAN NOT BE empty!\n")
return
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r") fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines() all_raw_text = fname.readlines()
input_count = len(all_raw_text) input_count = len(all_raw_text)
input_pos = 0 input_pos = 0
if args.sample_output_file == "": if args.sample_output_file is None:
print("Argument: sample-output-file can't be empty, setting it to\n") sample_output_file = args.sample_input_file + ".out"
print("\t args.sample_input_file.out") print('could not find `sample-output-file`, setting '
args.sample_output_file = args.sample_input_file+".out" 'it to {}'.format(sample_output_file))
fname_out = open(args.sample_output_file, "w+") fname_out = open(sample_output_file, "w+")
context_count=0 context_count = 0
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs=0 terminate_runs = 0
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos] raw_text = all_raw_text[input_pos]
...@@ -167,63 +116,62 @@ def generate_samples_input_from_file(model, tokenizer, args): ...@@ -167,63 +116,62 @@ def generate_samples_input_from_file(model, tokenizer, args):
if "stop" in raw_text: if "stop" in raw_text:
terminate_runs = 1 terminate_runs = 1
else: else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens) context_length = len(context_tokens)
if context_length >=args.seq_length//2: if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \ print("\nContext length", context_length, \
"\nPlease give smaller context (half of the sequence length)!") "\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue continue
else: else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens) context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1: if terminate_runs == 1:
return return
start_time = time.time() token_stream = get_token_stream(model, [context_tokens])
token_stream = get_token_stream(model, [context_tokens], tokenizer, args) for _, decode_tokens in enumerate(token_stream):
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):] trim_decode_tokens = tokenizer.detokenize(
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True) decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:") fname_out.write("\nContext:")
fname_out.write(raw_text) fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:") fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens) fname_out.write(trim_decode_tokens)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out.write("\n") fname_out.write("\n")
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group()) torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1 context_count += 1


def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')

@@ -231,83 +179,83 @@ def generate_samples_interactive(model, tokenizer, args):
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")


def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length - 1]
            text = tokenizer.detokenize(tokens)
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length - 1, 'finished': is_finished}
            yield datum

@@ -317,65 +265,73 @@ def generate_samples_unconditional(model, tokenizer, args):
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
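

# Illustrative sketch, not part of the original file: pad_batch right-pads
# every context to args.seq_length with pad_id and records the true lengths.
# The SimpleNamespace stands in for the parsed Megatron arguments.
def _demo_pad_batch():
    from types import SimpleNamespace

    args = SimpleNamespace(seq_length=8)
    batch = [[5, 6, 7], [1, 2]]
    padded, lengths = pad_batch(batch, pad_id=0, args=args)
    # padded  -> [[5, 6, 7, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0]]
    # lengths -> [3, 2]
    return padded, lengths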


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths


def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
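

# Illustrative sketch, not part of the original file: `switch` keeps val1
# where the mask is 0 and takes val2 where the mask is 1, which is how newly
# sampled tokens overwrite only the positions whose context has run out.
def _demo_switch():
    import torch

    val1 = torch.tensor([10., 20., 30.])
    val2 = torch.tensor([1., 2., 3.])
    started = torch.tensor([0, 1, 1])
    return switch(val1, val2, started)   # tensor([10., 2., 3.])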


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

@@ -390,11 +346,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):

            if args.recompute:
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
                types2use = None
@@ -404,113 +364,48 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break
def prepare_tokenizer(args):
tokenizer_args = {
'tokenizer_type': args.tokenizer_type,
'corpus': None,
'model_path': args.tokenizer_path,
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir}
tokenizer = make_tokenizer(**tokenizer_args)
args.tokenizer_num_tokens = tokenizer.num_tokens
args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
args.eod_token = tokenizer.get_command('eos').Id
after = tokenizer.num_tokens
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
if multiple != 0:
while (after % multiple) != 0:
after += 1
args.vocab_size = after
print("prepare tokenizer done", flush=True)
return tokenizer
def main():
"""Main training program."""
print('Generate Samples')
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Timer.
timers = Timers()
# Arguments.
args = get_args()
# Pytorch distributed.
initialize_distributed(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
#get the tokenizer
tokenizer = prepare_tokenizer(args)
# Model, optimizer, and learning rate.
model = setup_model(args)
#setting default batch size to 1
# args.batch_size = 1
args.device = torch.cuda.current_device()
#generate samples
if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model, tokenizer, args)
else:
generate_samples_interactive(model, tokenizer, args)
else:
write_and_generate_samples_unconditional(model, tokenizer, args)
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import build_tokenizer
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
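

# Illustrative sketch, not part of the original file: the mapping covers all
# 256 byte values, is invertible, and leaves printable ASCII untouched, which
# is what lets the byte-level BPE below handle arbitrary text without UNKs.
def _demo_bytes_to_unicode():
    mapping = bytes_to_unicode()
    assert len(mapping) == 256
    inverse = {v: k for k, v in mapping.items()}
    assert all(inverse[mapping[b]] == b for b in range(256))
    return mapping[ord('A')], mapping[ord(' ')]   # ('A', a remapped space char)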
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
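

# Illustrative sketch, not part of the original file: the symbol pairs of a
# word are what the merge loop in GPT2Tokenizer.bpe ranks and merges.
def _demo_get_pairs():
    word = ('l', 'o', 'w', 'er')
    return get_pairs(word)   # {('l', 'o'), ('o', 'w'), ('w', 'er')}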
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
from .file_utils import cached_path
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
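

# Illustrative usage sketch, not part of the original file. The vocab/merges
# paths are hypothetical; any GPT-2 style vocab.json / merges.txt pair works.
def _demo_gpt2_tokenizer(vocab_json='vocab.json', merges_txt='merges.txt'):
    tok = GPT2Tokenizer(vocab_json, merges_txt, errors='replace',
                        special_tokens=[], max_len=None)
    ids = tok.encode("Hello world")   # tokenize() + convert_tokens_to_ids()
    text = tok.decode(ids)            # byte-level decoding is lossless
    assert text == "Hello world"
    return ids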
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
def build_tokenizer(args):
"""Initialize tokenizer."""
if args.rank == 0:
print('> building {} tokenizer ...'.format(args.tokenizer_type),
flush=True)
# Select and instantiate the tokenizer.
assert args.vocab_file is not None
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
# Add vocab size.
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
args)
return tokenizer
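

# Illustrative sketch, not part of the original file: the minimal argument
# namespace build_tokenizer expects. The file paths are hypothetical and the
# SimpleNamespace stands in for the parsed Megatron arguments.
def _demo_build_tokenizer():
    from types import SimpleNamespace

    args = SimpleNamespace(rank=0,
                           tokenizer_type='GPT2BPETokenizer',
                           vocab_file='gpt2-vocab.json',
                           merge_file='gpt2-merges.txt',
                           make_vocab_size_divisible_by=128,
                           model_parallel_size=1)
    tokenizer = build_tokenizer(args)
    # build_tokenizer also sets args.padded_vocab_size as a side effect.
    return tokenizer, args.padded_vocab_size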
def _vocab_size_with_padding(orig_vocab_size, args):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
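

# Worked example, not part of the original file: with the 50257-token GPT-2
# vocab, make_vocab_size_divisible_by=128 and model_parallel_size=2, the
# multiple is 256 and the padded size becomes 50432.
def _demo_vocab_size_with_padding():
    from types import SimpleNamespace

    args = SimpleNamespace(make_vocab_size_divisible_by=128,
                           model_parallel_size=2,
                           rank=1)   # rank != 0 keeps the helper quiet
    return _vocab_size_with_padding(50257, args)   # -> 50432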
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
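

# Illustrative usage sketch, not part of the original file. The vocab path is
# hypothetical; tokenize() already returns ids for this wrapper.
def _demo_bert_wordpiece_tokenizer(vocab='bert-vocab.txt'):
    tok = _BertWordPieceTokenizer(vocab, lower_case=True)
    ids = tok.tokenize("Hello world")
    # The special-token ids used to build BERT-style inputs:
    return [tok.cls] + ids + [tok.sep], tok.pad, tok.mask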
class _GPT2BPETokenizer(AbstractTokenizer):
"""Original GPT2 BPE tokenizer."""
def __init__(self, vocab_file, merge_file):
name = 'GPT2 BPE'
super().__init__(name)
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
special_tokens=[], max_len=None)
self.eod_id = self.tokenizer.encoder['<|endoftext|>']
@property
def vocab_size(self):
return len(self.tokenizer.encoder)
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
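

# Illustrative usage sketch, not part of the original file. Paths are
# hypothetical; this wrapper exposes the tokenize/detokenize/eod interface
# that the text generation utilities above rely on.
def _demo_megatron_gpt2_tokenizer(vocab='gpt2-vocab.json',
                                  merges='gpt2-merges.txt'):
    tok = _GPT2BPETokenizer(vocab, merges)
    ids = tok.tokenize("Megatron-LM")
    assert tok.detokenize(ids) == "Megatron-LM"
    return ids, tok.eod   # tok.eod is the id of '<|endoftext|>'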

@@ -13,140 +13,120 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain utilities."""

from datetime import datetime
import math
import sys

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import report_memory


def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
             extra_args_provider=None, args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_val_test_data_provider: a function that builds datasets
            and returns `train, val, test` dataloaders.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It is
            used to set defaults for already-parsed arguments.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test dataset').start()
    train_data, val_data, test_data = train_val_test_data_provider()
    timers('train/valid/test dataset').stop()

    # Train, validation, and test data.
    timers('train/valid/test dataloader').start()
    train_data_iterator, val_data_iterator, \
        test_data_iterator = get_train_val_test_data_iterators(train_data,
                                                               val_data,
                                                               test_data)
    timers('train/valid/test dataloader').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test dataset',
                'train/valid/test dataloader'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration, _ = train(forward_step_func,
                             model, optimizer, lr_scheduler,
                             train_data_iterator, val_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   val_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)


def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
@@ -164,26 +144,24 @@ def get_model(model_provider_func, args):
    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))


def get_optimizer(model):
    """Set up the optimizer."""
    args = get_args()

    # Build parameter groups (weight decay and non-decay).
    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

@@ -194,8 +172,7 @@ def get_optimizer(model, args):
                param.model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    # Wrap into fp16 optimizer.
    if args.fp16:

@@ -210,8 +187,9 @@ def get_optimizer(model, args):
    return optimizer


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()

    # Add linear learning rate scheduler.
    if args.lr_decay_iters is not None:

@@ -219,13 +197,13 @@ def get_learning_rate_scheduler(optimizer, args):
    else:
        num_iters = args.train_iters
    num_iters = max(1, num_iters)
    init_step = 0
    warmup_iter = args.warmup * num_iters
    lr_scheduler = AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=warmup_iter,
        total_iters=num_iters,
        decay_style=args.lr_decay_style,
        last_iter=init_step,
        min_lr=args.min_lr,

@@ -235,23 +213,26 @@ def get_learning_rate_scheduler(optimizer, args):
    return lr_scheduler


def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
    else:
        args.iteration = 0

    return model, optimizer, lr_scheduler


def backward_step(optimizer, model, loss):
    """Backward step."""
    args = get_args()
    timers = get_timers()

    # Backward pass.
    optimizer.zero_grad()

@@ -279,19 +260,21 @@ def backward_step(optimizer, model, loss, args, timers):
        optimizer.clip_master_grads(args.clip_grad)


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Forward model for one step.
    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()
    torch.cuda.synchronize()

    # Calculate gradients, reduce across processes, and clip.
    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()
    torch.cuda.synchronize()

@@ -312,7 +295,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    for key in loss_dict:

@@ -368,8 +355,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,


def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, val_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

@@ -388,8 +377,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        skipped_iters += skipped_iter
        iteration += 1

@@ -397,42 +385,41 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, optimizer.loss_scale,
                                          report_memory_flag)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       val_data_iterator, model,
                                       iteration, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration, skipped_iters


def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

@@ -447,8 +434,7 @@ def evaluate(forward_step_func, data_iterator, model,
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))
            # Forward evaluation.
            _, loss_dict = forward_step_func(data_iterator, model)

            # Reduce across processes.
            for key in loss_dict:
                total_loss_dict[key] = total_loss_dict.get(key, 0.) + \

@@ -464,11 +450,11 @@ def evaluate(forward_step_func, data_iterator, model,


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())

@@ -486,23 +472,23 @@ def evaluate_and_print_results(prefix, forward_step_func,
    print_rank_0('-' * length)


def get_train_val_test_data_iterators(train_data, val_data, test_data):
    """Build train/validation/test iterators"""
    args = get_args()

    # Shift the start iterations.
    if train_data is not None:
        train_data.batch_sampler.start_iter = args.iteration % \
            len(train_data)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_data.batch_sampler.start_iter))

    if val_data is not None:
        start_iter_val = (args.iteration // args.eval_interval) * \
            args.eval_iters
        val_data.batch_sampler.start_iter = start_iter_val % \
            len(val_data)
        print_rank_0('setting validation data start iteration to {}'.
                     format(val_data.batch_sampler.start_iter))

    if train_data is not None:
        train_data_iterator = iter(train_data)

@@ -13,29 +13,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch

from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer
def reduce_losses(losses):
"""Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses)
reduced_losses = reduced_losses / torch.distributed.get_world_size()
return reduced_losses
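# Illustrative sketch, not part of the original file: reduce_losses() averages
# each scalar loss across the data-parallel ranks. The helper below mimics the
# same shape contract in a single process (world size 1), where the
# all-reduce-and-divide step is the identity.
def _reduce_losses_single_rank_example():
    losses = [torch.tensor(2.0), torch.tensor(4.0)]
    stacked = torch.cat([loss.clone().detach().view(1) for loss in losses])
    return stacked  # tensor([2., 4.]); with N ranks each entry is the mean over ranks.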
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
print_rank_0(string)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler):
"""Check for autoresume signal and exit if it is received."""
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistency.
torch.distributed.barrier()
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
autoresume.request_resume()
print_rank_0(">>> training terminated. Returning")
sys.exit(0)
def make_data_loader(dataset):
"""Buld dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
# Data parallel arguments.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
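# Illustrative sketch, not part of the original file: the sampler above draws a
# global batch of args.batch_size * world_size samples and each data-parallel
# rank keeps its own slice (shown here as a contiguous chunk for concreteness,
# with assumed values: per-GPU batch size 4, 8 data-parallel ranks).
def _data_parallel_slice_example(rank, batch_size=4, world_size=8):
    global_batch = list(range(batch_size * world_size))  # 32 sample indices
    start = rank * batch_size
    return global_batch[start:start + batch_size]  # e.g. rank 1 -> [4, 5, 6, 7]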
def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
-                                    eod_mask_loss):
+                                    eod_mask_loss,
+                                    fp16):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
@@ -85,468 +170,8 @@ def get_ltor_masks_and_position_ids(data,
            position_ids[b, (i+1):] -= (i + 1 - prev_index)
            prev_index = i + 1

+    # Convert
+    if fp16:
+        attention_mask = attention_mask.half()
+
    return attention_mask, loss_mask, position_ids
def reduce_losses(losses):
reduced_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses)
reduced_losses = reduced_losses / torch.distributed.get_world_size()
return reduced_losses
def get_tensorboard_writer(args):
writer = None
if hasattr(args, 'tensorboard_dir') and \
args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
return writer
def print_rank_0(message):
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def enable_adlr_autoresume(args):
print_rank_0('enabling autoresume ...')
import sys
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except:
print_rank_0('ADLR autoresume is not available, exiting ...')
exit()
args.AutoResume = AutoResume
args.AutoResume.init()
def check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args):
# Add barrier to ensure consistnecy.
torch.distributed.barrier()
if args.AutoResume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
args.AutoResume.request_resume()
print_rank_0(">>> training terminated. Returning")
exit(0)
def print_args(args, writer=None):
"""Print arguments."""
print('arguments:', flush=True)
for arg in vars(args):
dots = '.' * (29 - len(arg))
print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
if writer:
writer.add_text(arg, str(getattr(args, arg)))
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
class Timers:
"""Group of timers."""
class Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If timing is in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = self.Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '_time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0/ normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
print_rank_0(string)
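# Illustrative usage sketch, not part of the original file: timers are created
# lazily by name, accumulate time between start()/stop(), and log() prints the
# totals in milliseconds. Requires a CUDA device because start/stop synchronize
# the GPU; 'interval' is an assumed timer name.
def _timers_usage_example():
    timers = Timers()
    timers('interval').start()
    time.sleep(0.01)          # stand-in for real work
    timers('interval').stop()
    timers.log(['interval'])  # prints roughly: time (ms) | interval: 10.00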
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
print_rank_0(string)
def vocab_size_with_padding(num_tokens, args):
after = num_tokens
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
num_tokens, after - num_tokens, after))
return after
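# Illustrative sketch, not part of the original file: with an assumed GPT-2
# vocabulary of 50,257 tokens, make_vocab_size_divisible_by=128 and a
# model-parallel world size of 2, the loop above pads to the next multiple of
# 256, i.e. 50,432 (175 dummy tokens).
def _padded_vocab_example(num_tokens=50257, divisible_by=128, mp_world_size=2):
    multiple = divisible_by * mp_world_size
    after = num_tokens
    while after % multiple != 0:
        after += 1
    return after  # 50432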
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_checkpoint_name(checkpoints_path, iteration, release=False,
mp_rank=None):
if release:
d = 'release'
else:
d = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, d,
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None \
else mp_rank),
'model_optim_rng.pt')
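# Illustrative sketch, not part of the original file: for iteration 1000 and
# model-parallel rank 0, get_checkpoint_name() above yields a path of the form
#   <checkpoints_path>/iter_0001000/mp_rank_00/model_optim_rng.pt
def _checkpoint_name_example():
    return os.path.join('checkpoints', 'iter_{:07d}'.format(1000),
                        'mp_rank_{:02d}'.format(0), 'model_optim_rng.pt')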
def ensure_directory_exists(filename):
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_checkpoint_tracker_filename(checkpoints_path):
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
def save_checkpoint(iteration, model, optimizer,
lr_scheduler, args):
"""Save a model checkpoint."""
# Only rank zero of the data parallel group writes to the disk.
if isinstance(model, torchDDP):
model = model.module
if mpu.get_data_parallel_rank() == 0:
checkpoint_name = get_checkpoint_name(args.save, iteration)
print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
format(torch.distributed.get_rank(), iteration, checkpoint_name))
sd = {}
sd['iteration'] = iteration
sd['model'] = model.state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
sd['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
sd['lr_scheduler'] = lr_scheduler.state_dict()
# rng states.
if not args.no_save_rng:
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
ensure_directory_exists(checkpoint_name)
torch.save(sd, checkpoint_name)
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
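# Illustrative sketch, not part of the original file: the tracker file written
# above is a one-line text file naming the latest saved iteration (or the
# string 'release'); load_checkpoint() below reads it to decide where to resume.
def _read_tracker_example(checkpoints_path='checkpoints'):
    with open(get_checkpoint_tracker_filename(checkpoints_path), 'r') as f:
        return f.read().strip()  # e.g. '1000' or 'release'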
def load_checkpoint(model, optimizer, lr_scheduler, args):
"""Load a model checkpoint."""
if isinstance(model, torchDDP):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try:
sd = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
import sys
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16.loss_scaler']
sd = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
except:
print_rank_0('could not load the checkpoint')
exit()
# Iterations.
if args.finetune or release:
iteration = 0
else:
try:
iteration = sd['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = sd['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load iteration '
'from checkpoint {}, exiting'.format(checkpoint_name))
exit()
# Model.
try:
model.load_state_dict(sd['model'])
except KeyError:
print_rank_0('A metadata file exists but unable to load model '
'from checkpoint {}, exiting'.format(checkpoint_name))
exit()
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(sd['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(sd['lr_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer '
'state.'.format(checkpoint_name))
exit()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
torch.cuda.set_rng_state(sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}, exiting. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the rng '
'state.'.format(checkpoint_name))
exit()
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return iteration
def load_weights(src, dst, dst2src=False):
"""
Loads weights from src to dst via in place copy.
src is a huggingface gpt2model, while dst is one of our models.
dst2src=True loads parameters from our models into huggingface's.
^dst2src is still untested
"""
conv_layer = 'Conv1D' in str(type(src))
for n, p in src.named_parameters():
if dst2src:
data = dst._parameters[n].data
load = p.data
else:
data = p.data
load = dst._parameters[n].data
if conv_layer and 'weight' in n:
data = data.t().contiguous()
load.copy_(data)
# dst._parameters[n].data.copy_(data)
def load_mlp(our, oai, dst2src=False):
load_weights(oai.c_fc, our.dense_h_to_4h, dst2src)
load_weights(oai.c_proj, our.dense_4h_to_h, dst2src)
def load_attention(our, oai, dst2src=False):
load_weights(oai.c_attn, our.query_key_value, dst2src)
load_weights(oai.c_proj, our.dense, dst2src)
def load_transformer_layer(our, oai, dst2src=False):
load_weights(oai.ln_1, our.input_layernorm, dst2src)
load_weights(oai.ln_2, our.post_attention_layernorm, dst2src)
load_mlp(our.mlp, oai.mlp, dst2src)
load_attention(our.attention, oai.attn, dst2src)
def move_weights(our, oai, dst2src=False):
"""
Loads weights from `oai` to `our` via in place copy.
`oai` is a huggingface gpt2model, while `our` is one of our models.
dst2src=True loads parameters from our models into huggingface's.
^dst2src=True is still untested
"""
# while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)):
# our=our.module
transformer_model = oai.transformer
load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src)
load_weights(transformer_model.wte, our.word_embeddings, dst2src)
load_weights(transformer_model.wpe, our.position_embeddings, dst2src)
for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
load_transformer_layer(our_layer, oai_layer, dst2src)
def merge_parallel_state_dicts(state_dicts):
    """Group the values for each key across the model-parallel state dicts."""
    temp_sd = {}
    for sd in state_dicts:
        for k, v in sd.items():
            temp_sd.setdefault(k, []).append(v)
    return temp_sd
def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
pass
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain ALBERT"""
import torch
import torch.nn.functional as F
from megatron import mpu
from megatron.model import BertModel
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.data.albert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
def model_provider(args):
"""Build the model."""
print_rank_0('building BERT model ...')
model = BertModel(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size,
parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32)
return model
def get_batch(data_iterator, timers):
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model, args, timers):
"""Forward step."""
# Get the batch.
timers('batch generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator, timers)
timers('batch generator').stop()
# Forward model.
lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
sentence_order.view(-1).contiguous(),
ignore_index=-1)
lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
lm_labels.contiguous())
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss
reduced_losses = reduce_losses([lm_loss, sop_loss])
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
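# Illustrative sketch, not part of the original file: the masked-LM term above
# is a loss-mask-weighted mean over token positions; a minimal single-tensor
# analogue of that reduction:
def _masked_lm_mean_example():
    per_token_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
    loss_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])
    return torch.sum(per_token_loss * loss_mask) / loss_mask.sum()  # tensor(2.)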
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
(train_data, valid_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
print_rank_0('> building train, validation, and test datasets '
'for ALBERT ...')
if args.data_loader is None:
args.data_loader = 'binary'
if args.data_loader != 'binary':
print('Unsupported {} data loader for ALBERT.'.format(
args.data_loader))
exit(1)
if not args.data_path:
print('ALBERT only supports a unified dataset specified '
'with --data-path')
exit(1)
data_parallel_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * data_parallel_size
# Number of train/valid/test samples.
train_iters = args.train_iters
eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [args.train_iters * global_batch_size,
eval_iters * global_batch_size,
test_iters * global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
assert len(args.data_path) == 1
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
vocab_file=args.vocab,
data_prefix=args.data_path[0],
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=args.skip_mmap_warmup)
print_rank_0("> finished creating ALBERT datasets ...")
def make_data_loader_(dataset):
if not dataset:
return None
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(
sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=data_parallel_rank,
world_size=data_parallel_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True)
train_data = make_data_loader_(train_ds)
valid_data = make_data_loader_(valid_ds)
test_data = make_data_loader_(test_ds)
do_train = train_data is not None and args.train_iters > 0
do_valid = valid_data is not None and args.eval_iters > 0
do_test = test_data is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
token_counts = torch.cuda.LongTensor([num_tokens,
2, # hard coded num_type_tokens
int(do_train),
int(do_valid),
int(do_test)])
else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(token_counts,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
args.vocab_size = token_counts[0].item()
args.tokentype_size = token_counts[1].item()
args.do_train = token_counts[2].item()
args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item()
return train_data, valid_data, test_data
if __name__ == "__main__":
run('Pretrain BERT model', get_train_val_test_data,
model_provider, forward_step)
@@ -18,81 +18,70 @@
import torch
import torch.nn.functional as F

-from configure_data import configure_data
+from megatron import get_args
+from megatron import get_timers
from megatron import mpu
+from megatron import print_rank_0
+from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import BertModel
-from megatron.utils import print_rank_0
+from megatron.training import pretrain
+from megatron.utils import make_data_loader
from megatron.utils import reduce_losses
-from megatron.utils import vocab_size_with_padding
-from megatron.training import run


-def model_provider(args):
+def model_provider():
    """Build the model."""
+    args = get_args()

    print_rank_0('building BERT model ...')

    model = BertModel(
-        num_layers=args.num_layers,
-        vocab_size=args.vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations,
-        checkpoint_num_layers=args.checkpoint_num_layers,
+        num_tokentypes=2,
        add_binary_head=True,
-        layernorm_epsilon=args.layernorm_epsilon,
-        num_tokentypes=args.tokentype_size,
-        parallel_output=True,
-        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+        parallel_output=True)

    return model


-def get_batch(data_iterator, timers):
+def get_batch(data_iterator):

    # Items and their type.
-    keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
+    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
-    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
-    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
-    next_sentence = data_b['is_random'].long()
-    loss_mask = data_b['mask'].float()
-    lm_labels = data_b['mask_labels'].long()
-    padding_mask = data_b['pad_mask'].long()
+    sentence_order = data_b['is_random'].long()
+    loss_mask = data_b['loss_mask'].float()
+    lm_labels = data_b['labels'].long()
+    padding_mask = data_b['padding_mask'].long()

-    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask
+    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
    """Forward step."""
+    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
-    tokens, types, next_sentence, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator, timers)
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
+        = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
-    lm_logits, nsp_logits = model(tokens, 1-padding_mask, tokentype_ids=types)
+    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)

-    nsp_loss = F.cross_entropy(nsp_logits.view(-1, 2).contiguous().float(),
-                               next_sentence.view(-1).contiguous(),
-                               ignore_index=-1)
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
+                               sentence_order.view(-1).contiguous(),
+                               ignore_index=-1)

    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
@@ -100,57 +89,77 @@ def forward_step(data_iterator, model, args, timers):
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-    loss = lm_loss + nsp_loss
+    loss = lm_loss + sop_loss

-    reduced_losses = reduce_losses([lm_loss, nsp_loss])
+    reduced_losses = reduce_losses([lm_loss, sop_loss])

-    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}
+    return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


-def get_train_val_test_data(args):
+def get_train_val_test_data():
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+    args = get_args()

-    (train_data, val_data, test_data) = (None, None, None)
+    (train_data, valid_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
-        if (args.data_loader == 'raw'
-            or args.data_loader == 'lazy'
-            or args.data_loader == 'tfrecords'):
-            data_config = configure_data()
-            ds_type = 'BERT'
-            data_config.set_defaults(data_set_type=ds_type, transpose=False)
-            (train_data, val_data, test_data), tokenizer = data_config.apply(args)
-            num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
-            # Need to broadcast num_tokens and num_type_tokens.
-            token_counts = torch.cuda.LongTensor([num_tokens,
-                                                  tokenizer.num_type_tokens,
-                                                  int(args.do_train),
-                                                  int(args.do_valid),
-                                                  int(args.do_test)])
-        else:
-            print("Unsupported data loader for BERT.")
-            exit(1)
+        print_rank_0('> building train, validation, and test datasets '
+                     'for BERT ...')
+
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            max_seq_length=args.seq_length,
+            masked_lm_prob=args.mask_prob,
+            short_seq_prob=args.short_seq_prob,
+            seed=args.seed,
+            skip_warmup=(not args.mmap_warmup))
+        print_rank_0("> finished creating BERT datasets ...")
+
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)
+
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
    else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
+        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    num_type_tokens = token_counts[1].item()
-    args.do_train = token_counts[2].item()
-    args.do_valid = token_counts[3].item()
-    args.do_test = token_counts[4].item()
-    args.vocab_size = num_tokens
-    args.tokentype_size = num_type_tokens
+    torch.distributed.broadcast(flags,
+                                mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()

-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data


if __name__ == "__main__":
-    run('Pretrain BERT model', get_train_val_test_data,
-        model_provider, forward_step)
+    pretrain(get_train_val_test_data, model_provider, forward_step,
+             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -15,55 +15,47 @@
"""Pretrain GPT2""" """Pretrain GPT2"""
import os
import torch import torch
from configure_data import configure_data from megatron import get_args
from gpt2_data_loader import make_gpt2_dataloaders from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.model import GPT2Model from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import print_rank_0 from megatron.utils import make_data_loader
from megatron.utils import reduce_losses from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
def model_provider(args): def model_provider():
"""Build the model.""" """Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...') print_rank_0('building GPT2 model ...')
model = GPT2Model(num_layers=args.num_layers, model = GPT2Model(num_tokentypes=0, parallel_output=True)
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
layernorm_epsilon=args.layernorm_epsilon,
parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32)
return model return model
def get_batch(data_iterator, args, timers): def get_batch(data_iterator):
"""Generate a batch""" """Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type. # Items and their type.
keys = ['text'] keys = ['text']
datatype = torch.int64 datatype = torch.int64
# Broadcast data. # Broadcast data.
timers('data loader').start()
if data_iterator is not None: if data_iterator is not None:
data = next(data_iterator) data = next(data_iterator)
else: else:
data = None data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype) data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack. # Unpack.
...@@ -74,24 +66,23 @@ def get_batch(data_iterator, args, timers): ...@@ -74,24 +66,23 @@ def get_batch(data_iterator, args, timers):
# Get the masks and postition ids. # Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens, tokens,
args.eod_token, tokenizer.eod,
args.reset_position_ids, args.reset_position_ids,
args.reset_attention_mask, args.reset_attention_mask,
args.eod_mask_loss) args.eod_mask_loss,
# Convert args.fp16)
if args.fp16:
attention_mask = attention_mask.half()
return tokens, labels, loss_mask, attention_mask, position_ids return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model, args, timers): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch( tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator, args, timers) data_iterator)
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
...@@ -107,60 +98,71 @@ def forward_step(data_iterator, model, args, timers): ...@@ -107,60 +98,71 @@ def forward_step(data_iterator, model, args, timers):
return loss, {'lm loss': reduced_loss[0]} return loss, {'lm loss': reduced_loss[0]}
def get_train_val_test_data(args): def make_gpt2_dataloaders():
"""Build gpt2 dataloders."""
args = get_args()
# Input parameters.
input_data_sizes_file = args.input_data_sizes_file
seq_length = args.seq_length
initial_seed = args.seed
# Build the datasets.
def _build_dataset(name):
return GPT2Dataset(os.path.join(args.data_path, name),
args.input_data_sizes_file,
args.seq_length, args.seed)
train_ds = _build_dataset('train')
valid_ds = _build_dataset('valid')
test_ds = _build_dataset('test')
# Dataloaders
train = make_data_loader(train_ds)
valid = make_data_loader(valid_ds)
test = make_data_loader(test_ds)
args.do_train = False
args.do_valid = False
args.do_test = False
if train is not None:
args.do_train = True
if valid is not None:
args.do_valid = True
if test is not None:
args.do_test = True
return (train, valid, test)
def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS.""" """Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, val_data, test_data) = (None, None, None) (train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
if args.data_loader == 'numpy':
assert len(args.train_data) == 1 (train_data, val_data, test_data) = make_gpt2_dataloaders()
args.train_data = args.train_data[0] flags = torch.cuda.LongTensor([int(args.do_train),
assert len(args.valid_data) == 1 int(args.do_valid),
args.valid_data = args.valid_data[0] int(args.do_test)])
assert len(args.test_data) == 1
args.test_data = args.test_data[0]
(train_data, val_data, test_data), num_tokens, \
eod_token = make_gpt2_dataloaders(args)
elif args.data_loader == 'raw' or args.data_loader == 'lazy':
data_config = configure_data()
data_config.set_defaults(data_set_type='GPT2', transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply(
args)
num_tokens = tokenizer.num_tokens
eod_token = tokenizer.get_command('eos').Id
assert eod_token == tokenizer.get_command('pad').Id
else:
print("Unsupported data loader for GPT2.")
exit(1)
# pad.
num_tokens = vocab_size_with_padding(num_tokens, args)
print_rank_0('> found end-of-document token: {}'.format(eod_token))
token_counts = torch.cuda.LongTensor([num_tokens, eod_token,
int(args.do_train),
int(args.do_valid),
int(args.do_test)])
else: else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0]) flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens. # Broadcast num tokens.
torch.distributed.broadcast(token_counts, torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(), mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group()) group=mpu.get_model_parallel_group())
num_tokens = token_counts[0].item() args.do_train = flags[0].item()
eod_token = token_counts[1].item() args.do_valid = flags[1].item()
args.do_train = token_counts[2].item() args.do_test = flags[2].item()
args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item()
args.vocab_size = num_tokens
args.eod_token = eod_token
return train_data, val_data, test_data return train_data, val_data, test_data
if __name__ == "__main__": if __name__ == "__main__":
run('Pretrain GPT-2 model', get_train_val_test_data, pretrain(get_train_val_test_data, model_provider, forward_step,
model_provider, forward_step) args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})