Commit 37ae6646 authored by Vijay Korthikanti

Add support for the old checkpoint format

parent 4bf923d5
@@ -313,6 +313,8 @@ def _add_checkpointing_args(parser):
                   help='Load model for finetuning. Do not load optimizer '
                   'or rng state from checkpoint and set iteration to 0. '
                   'Assumed when loading a release checkpoint.')
group.add_argument('--old-checkpoint-format', action='store_true',
                   help='Load the old checkpoint format [Q[]K[]V[]].')
return parser
...
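For reference, a minimal standalone sketch (not part of the commit) of how the new flag behaves: as a `store_true` argument it defaults to `False`, so checkpoints keep loading exactly as before unless `--old-checkpoint-format` is passed explicitly.

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the argument added above; store_true means the attribute is False by default.
parser.add_argument('--old-checkpoint-format', action='store_true',
                    help='Load the old checkpoint format [Q[]K[]V[]].')

args = parser.parse_args([])                           # flag omitted
assert args.old_checkpoint_format is False
args = parser.parse_args(['--old-checkpoint-format'])  # flag supplied
assert args.old_checkpoint_format is True
```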
@@ -120,6 +120,7 @@ class ParallelSelfAttention(MegatronModule):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.old_checkpoint_format = args.old_checkpoint_format
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -170,7 +171,23 @@ class ParallelSelfAttention(MegatronModule):
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer):
    """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(transpose)
    [s, b, hp, 3] -->(view) [s, b, 3 * hp]"""
    input_shape = mixed_layer.size()
    last_dim = input_shape[-1]
    assert last_dim % 3 == 0
    last_dim_split = last_dim // 3
    # Split the fused Q/K/V dim into (3, hp), swap those two axes, and flatten back.
    intermediate_shape = input_shape[:-1] + \
        (3, last_dim_split)
    mixed_layer = mixed_layer.view(*intermediate_shape)
    mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
    mixed_layer = mixed_layer.view(*input_shape)
    return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
            get_key_value=False):
@@ -182,6 +199,9 @@ class ParallelSelfAttention(MegatronModule):
# Attention heads [s, b, hp] --> [s, b, 3 * hp]
mixed_x_layer, _ = self.query_key_value(hidden_states)
if self.old_checkpoint_format:
    mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
# [s, b, 3 * hp] --> [s, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
...
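For reference, a minimal standalone sketch (not part of the commit) of the layout conversion that `_transpose_last_dim` performs, using a toy tensor with hp = 4: the old checkpoint stores the fused last dimension as contiguous Q, K, V blocks, and the view/transpose/view sequence interleaves them per position.

```python
import torch

# Toy shapes: sequence length 1, batch 1, hp = 4, so the fused last dim is 3 * hp = 12.
s, b, hp = 1, 1, 4
# Old checkpoint layout along the last dim: Q block, then K block, then V block.
old = torch.arange(3 * hp).view(s, b, 3 * hp)  # [0..3] = Q, [4..7] = K, [8..11] = V

# Same view -> transpose -> view sequence as _transpose_last_dim.
new = old.view(s, b, 3, hp).transpose(-1, -2).contiguous().view(s, b, 3 * hp)

print(old[0, 0].tolist())  # [0, 1, 2, ..., 11]       blocked Q | K | V
print(new[0, 0].tolist())  # [0, 4, 8, 1, 5, 9, ...]  Q/K/V interleaved per position
```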