Commit a8d47812 authored by Vijay Korthikanti

checkpoint versioning

parent 7d4ad51e
@@ -313,8 +313,6 @@ def _add_checkpointing_args(parser):
                        help='Load model for finetuning. Do not load optimizer '
                        'or rng state from checkpoint and set iteration to 0. '
                        'Assumed when loading a release checkpoint.')
-    group.add_argument('--old-checkpoint-format', action='store_true',
-                       help='load old checkpoint format[Q[]K[]V[]].')
     return parser
......
@@ -27,6 +27,15 @@ from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
 
+_CHECKPOINT_VERSION = None
+
+def set_checkpoint_version(value):
+    global _CHECKPOINT_VERSION
+    _CHECKPOINT_VERSION = value
+
+def get_checkpoint_version():
+    global _CHECKPOINT_VERSION
+    return _CHECKPOINT_VERSION
 
 def check_checkpoint_args(checkpoint_args):
     """Ensure fixed arguments for a model are the same for the input
@@ -90,6 +99,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # Arguments, iteration, and model.
         state_dict = {}
         state_dict['args'] = args
+        state_dict['checkpoint_version'] = 1
         state_dict['iteration'] = iteration
         state_dict['model'] = model.state_dict_for_save_checkpoint()
@@ -184,6 +194,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
             print_rank_0('could not load the checkpoint')
             sys.exit()
 
+    # set checkpoint version
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
     # Set iteration.
     if args.finetune or release:
         iteration = 0
@@ -198,6 +211,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                              'iteration from checkpoint {}, exiting'.format(
                                  checkpoint_name))
                 sys.exit()
+
     # Check arguments.
     if 'args' in state_dict:
......
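Taken together, the checkpointing changes above tag every newly saved checkpoint with a version number and fall back to version 0 when loading files written before this commit. A minimal, self-contained sketch of that pattern follows; the save/load wrapper names and the toy state dict are illustrative, not part of the commit:

```python
import torch

_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    _CHECKPOINT_VERSION = value

def get_checkpoint_version():
    return _CHECKPOINT_VERSION

def save_checkpoint_sketch(path, model_state):
    # New checkpoints carry an explicit version tag alongside the weights.
    torch.save({'checkpoint_version': 1, 'model': model_state}, path)

def load_checkpoint_sketch(path):
    state_dict = torch.load(path, map_location='cpu')
    # Pre-versioning checkpoints have no tag; treat them as version 0 so
    # downstream code (e.g. the attention layer) can apply compatibility fixes.
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
    return state_dict['model']
```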
@@ -23,6 +23,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.checkpointing import get_checkpoint_version
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu
@@ -120,7 +121,6 @@ class ParallelSelfAttention(MegatronModule):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
-        self.old_checkpoint_format = args.old_checkpoint_format
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -178,7 +178,7 @@ class ParallelSelfAttention(MegatronModule):
         input_shape = mixed_layer.size();
         last_dim = input_shape[-1]
-        assert last_dim % 3 == 0
+        assert last_dim % 3 == 0, "expected QKV dimension"
         last_dim_split = last_dim // 3
         intermediate_shape = input_shape[:-1] +\
@@ -200,7 +200,9 @@ class ParallelSelfAttention(MegatronModule):
         # Attention heads [s, b, hp] --> [s, b, hp * 3]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        if self.old_checkpoint_format:
+        checkpoint_version = get_checkpoint_version()
+        if checkpoint_version is not None and \
+           checkpoint_version == 0:
             # [s, b, 3 * hp] --> [s, b, hp * 3]
             mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
......
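On the model side, the per-instance old_checkpoint_format flag is replaced by a query of the global checkpoint version: version-0 checkpoints store the fused QKV projection in the old [Q[]K[]V[]] block layout, so its output is reshuffled before the heads are split. The sketch below mimics the view/transpose shape juggling suggested by the hunk at -178 on a toy tensor; the helper name and the exact transpose direction are assumptions, not copied from the commit:

```python
import torch

def transpose_last_dim(mixed_layer):
    # View the last dim as (3, d), swap the trailing axes, and flatten back,
    # turning a [q0 .. qN, k0 .. kN, v0 .. vN] block layout into an
    # interleaved [q0, k0, v0, q1, k1, v1, ...] layout.
    input_shape = mixed_layer.size()
    last_dim = input_shape[-1]
    assert last_dim % 3 == 0, "expected QKV dimension"
    last_dim_split = last_dim // 3
    x = mixed_layer.view(*input_shape[:-1], 3, last_dim_split)
    return x.transpose(-1, -2).contiguous().view(*input_shape)

# Toy check: 6 features laid out as [q0, q1, k0, k1, v0, v1].
old_layout = torch.tensor([[0., 1., 2., 3., 4., 5.]])
print(transpose_last_dim(old_layout))  # tensor([[0., 2., 4., 1., 3., 5.]])
```

In the model this rewrite only runs when get_checkpoint_version() reports 0, so checkpoints saved with the new layout pass through untouched.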