Commit f0a445fa authored by mohammad

added consumed tokens to checkpoints and some refactoring

parent 4311b695
@@ -72,6 +72,9 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)
 
+    # Consumed tokens.
+    args.consumed_train_samples = 0
+    args.consumed_valid_samples = 0
 
     # Set input defaults.
     for key in defaults:
...
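Despite the "Consumed tokens" comment, the new counters track samples; with a fixed sequence length the token count is just a multiple of them. A minimal sketch of that conversion, assuming the usual fixed-length args.seq_length argument (the variable names below are illustrative, not part of the commit):

    # Hypothetical conversion; only valid when every sample has the same length.
    consumed_train_tokens = args.consumed_train_samples * args.seq_length
    consumed_valid_tokens = args.consumed_valid_samples * args.seq_length
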
@@ -89,7 +89,8 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
 
 
-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                    consumed_train_samples=None, consumed_valid_samples=None):
     """Save a model checkpoint."""
     args = get_args()
@@ -103,6 +104,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 2.0
         state_dict['iteration'] = iteration
+        if consumed_train_samples:
+            state_dict['consumed_train_samples'] = consumed_train_samples
+        if consumed_valid_samples:
+            state_dict['consumed_valid_samples'] = consumed_valid_samples
         state_dict['model'] = model.state_dict_for_save_checkpoint()
 
         # Optimizer stuff.
...
@@ -213,7 +218,13 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                     'iteration from checkpoint {}, exiting'.format(
                         checkpoint_name))
                 sys.exit()
 
+    if 'consumed_train_samples' in state_dict:
+        assert args.consumed_train_samples == 0
+        args.consumed_train_samples = state_dict['consumed_train_samples']
+    if 'consumed_valid_samples' in state_dict:
+        assert args.consumed_valid_samples == 0
+        args.consumed_valid_samples = state_dict['consumed_valid_samples']
 
     # Check arguments.
     if 'args' in state_dict:
...
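Together with the save-side change above, resuming a run restores the counters into args before any data loaders are built. A rough sketch of the round trip (assuming load_checkpoint returns the checkpoint iteration; the call sites appear in the training hunks further down):

    # Save side: counters are only written when non-zero (truthiness check above).
    save_checkpoint(iteration, model, optimizer, lr_scheduler,
                    consumed_train_samples=args.consumed_train_samples,
                    consumed_valid_samples=args.consumed_valid_samples)

    # Load side: parse_args() left both counters at 0, so the asserts pass and
    # any values stored in the checkpoint flow back into args.
    args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
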
@@ -13,7 +13,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Megatorn Sampler."""
+"""Dataloaders."""
+
+import torch
+
+from megatron import get_args
+from megatron import mpu
+
+
+def build_pretraining_data_loader(dataset, consumed_samples):
+    """Buld dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * world_size
+
+    # Megatron sampler
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=consumed_samples,
+        global_batch_size=global_batch_size,
+        rank=mpu.get_data_parallel_rank(),
+        world_size=world_size)
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=args.num_workers,
+                                       pin_memory=True)
 
 
 class MegatronPretrainingSampler:
...
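The body of MegatronPretrainingSampler is collapsed in this diff. A minimal sketch of what such a batch sampler plausibly does, skipping indices that were already consumed and handing each data-parallel rank a contiguous slice of every global batch (the slicing strategy and attribute names are assumptions, not the committed implementation):

    class MegatronPretrainingSampler:

        def __init__(self, total_samples, consumed_samples, global_batch_size,
                     rank, world_size):
            # Bookkeeping for sequential, resumable sampling.
            self.total_samples = total_samples
            self.consumed_samples = consumed_samples
            self.global_batch_size = global_batch_size
            # Each data-parallel rank reads a contiguous slice of the global batch.
            per_rank = global_batch_size // world_size
            self.start = rank * per_rank
            self.end = self.start + per_rank

        def __len__(self):
            return self.total_samples

        def __iter__(self):
            batch = []
            for idx in range(self.consumed_samples, self.total_samples):
                batch.append(idx)
                if len(batch) == self.global_batch_size:
                    # Yield only this rank's share of the global batch.
                    yield batch[self.start:self.end]
                    batch = []

Passed as batch_sampler to torch.utils.data.DataLoader (as in the hunk above), an object like this makes the loader resume exactly where a previous run stopped.
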
@@ -37,7 +37,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import make_data_loader
+from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory
@@ -104,7 +104,9 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    iteration, False)
 
     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                        consumed_train_samples=args.consumed_train_samples,
+                        consumed_valid_samples=args.consumed_valid_samples)
 
     if args.do_test:
         # Run on test data.
@@ -224,7 +226,8 @@ def setup_model_and_optimizer(model_provider_func):
     while hasattr(unwrapped_model, 'module'):
         unwrapped_model = unwrapped_model.module
 
-    if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
+    if args.iteration == 0 and hasattr(unwrapped_model,
+                                       'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
         unwrapped_model.init_state_dict_from_bert()
@@ -414,6 +417,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                                        optimizer,
                                                        lr_scheduler)
         iteration += 1
+        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+                                       args.batch_size
 
         # Logging.
         loss_scale = None
@@ -433,7 +438,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                            consumed_train_samples=args.consumed_train_samples,
+                            consumed_valid_samples=args.consumed_valid_samples)
 
         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -472,6 +479,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                                 args.eval_iters))
             # Forward evaluation.
             _, loss_dict = forward_step_func(data_iterator, model)
+            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
+                                           * args.batch_size
 
             # Reduce across processes.
             for key in loss_dict:
                 total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
@@ -517,11 +526,19 @@ def build_train_valid_test_data_iterators(
     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
 
     print_rank_0('> building train, validation, and test datasets ...')
 
+    # Rank and global batch size.
+    data_parallel_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * data_parallel_size
+
+    # Backward compatibility, assume fixed batch size.
+    if args.iteration > 0 and args.consumed_train_samples == 0:
+        args.consumed_train_samples = args.iteration * global_batch_size
+    if args.iteration > 0 and args.consumed_valid_samples == 0:
+        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+            args.eval_iters * global_batch_size
+
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        # Rank, size, and global batch size.
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        global_batch_size = args.batch_size * data_parallel_size
 
         # Number of train/valid/test samples.
         train_iters = args.train_iters
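A quick worked example of the backward-compatibility arithmetic, using purely illustrative numbers for a resumed run whose checkpoint predates the new counters:

    # Hypothetical configuration, for illustration only.
    iteration = 1000
    batch_size = 8              # args.batch_size, per data-parallel rank
    data_parallel_size = 4
    eval_interval = 100
    eval_iters = 10

    global_batch_size = batch_size * data_parallel_size              # 32
    consumed_train_samples = iteration * global_batch_size           # 32,000
    consumed_valid_samples = ((iteration // eval_interval)
                              * eval_iters * global_batch_size)      # 3,200

The estimate is exact only if the batch size never changed over the run, which is what the "assume fixed batch size" comment warns about.
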
@@ -540,12 +557,11 @@ def build_train_valid_test_data_iterators(
                                                 train_val_test_num_samples)
 
         # Build dataloders.
-        comsumed_samples = args.iteration * global_batch_size
-        train_dataloader = make_data_loader(train_ds, comsumed_samples)
-        comsumed_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * global_batch_size
-        valid_dataloader = make_data_loader(valid_ds, comsumed_samples)
-        test_dataloader = make_data_loader(test_ds, comsumed_samples)
+        train_dataloader = build_pretraining_data_loader(
+            train_ds, args.consumed_train_samples)
+        valid_dataloader = build_pretraining_data_loader(
+            valid_ds, args.consumed_valid_samples)
+        test_dataloader = build_pretraining_data_loader(test_ds, 0)
 
         # Flags to know if we need to do training/validation/testing.
         do_train = train_dataloader is not None and args.train_iters > 0
...
...@@ -24,7 +24,6 @@ from megatron import print_rank_0 ...@@ -24,7 +24,6 @@ from megatron import print_rank_0
from megatron import get_adlr_autoresume from megatron import get_adlr_autoresume
from megatron import mpu from megatron import mpu
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.data.sampler import MegatronPretrainingSampler
from megatron.fp16 import FP16_Optimizer from megatron.fp16 import FP16_Optimizer
@@ -89,30 +88,6 @@ def check_adlr_autoresume_termination(iteration, model,
         sys.exit(0)
 
 
-def make_data_loader(dataset, consumed_samples):
-    """Buld dataloader given an input dataset."""
-
-    if dataset is None:
-        return None
-    args = get_args()
-
-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    # Megatron sampler
-    batch_sampler = MegatronPretrainingSampler(
-        total_samples=len(dataset), consumed_samples=consumed_samples,
-        global_batch_size=global_batch_size, rank=rank, world_size=world_size)
-
-    # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
-
-
 def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
...